diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7a89c01fee73..daee09de8826 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -44,7 +44,9 @@ exportMethods("glm", "spark.gaussianMixture", "spark.als", "spark.kstest", - "spark.logit") + "spark.logit", + "spark.randomForest", + "spark.gbt") # Job group lifecycle management methods export("setJobGroup", @@ -350,7 +352,11 @@ export("as.DataFrame", "uncacheTable", "print.summary.GeneralizedLinearRegressionModel", "read.ml", - "print.summary.KSTest") + "print.summary.KSTest", + "print.summary.RandomForestRegressionModel", + "print.summary.RandomForestClassificationModel", + "print.summary.GBTRegressionModel", + "print.summary.GBTClassificationModel") export("structField", "structField.jobj", @@ -375,6 +381,10 @@ S3method(print, structField) S3method(print, structType) S3method(print, summary.GeneralizedLinearRegressionModel) S3method(print, summary.KSTest) +S3method(print, summary.RandomForestRegressionModel) +S3method(print, summary.RandomForestClassificationModel) +S3method(print, summary.GBTRegressionModel) +S3method(print, summary.GBTClassificationModel) S3method(structField, character) S3method(structField, jobj) S3method(structType, jobj) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 1df8bbf9fe60..1cf9b38ea648 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -788,7 +788,7 @@ setMethod("write.json", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) - invisible(callJMethod(write, "json", path)) + invisible(handledCallJMethod(write, "json", path)) }) #' Save the contents of SparkDataFrame as an ORC file, preserving the schema. @@ -819,7 +819,7 @@ setMethod("write.orc", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) - invisible(callJMethod(write, "orc", path)) + invisible(handledCallJMethod(write, "orc", path)) }) #' Save the contents of SparkDataFrame as a Parquet file, preserving the schema. @@ -851,7 +851,7 @@ setMethod("write.parquet", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) - invisible(callJMethod(write, "parquet", path)) + invisible(handledCallJMethod(write, "parquet", path)) }) #' @rdname write.parquet @@ -895,7 +895,7 @@ setMethod("write.text", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) - invisible(callJMethod(write, "text", path)) + invisible(handledCallJMethod(write, "text", path)) }) #' Distinct @@ -3342,7 +3342,7 @@ setMethod("write.jdbc", jprops <- varargsToJProperties(...) write <- callJMethod(x@sdf, "write") write <- callJMethod(write, "mode", jmode) - invisible(callJMethod(write, "jdbc", url, tableName, jprops)) + invisible(handledCallJMethod(write, "jdbc", url, tableName, jprops)) }) #' randomSplit diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 216ca51666ba..38d83c6e5c52 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -350,7 +350,7 @@ read.json.default <- function(path, ...) { paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "json", paths) + sdf <- handledCallJMethod(read, "json", paths) dataFrame(sdf) } @@ -422,7 +422,7 @@ read.orc <- function(path, ...) 
{ path <- suppressWarnings(normalizePath(path)) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "orc", path) + sdf <- handledCallJMethod(read, "orc", path) dataFrame(sdf) } @@ -444,7 +444,7 @@ read.parquet.default <- function(path, ...) { paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "parquet", paths) + sdf <- handledCallJMethod(read, "parquet", paths) dataFrame(sdf) } @@ -496,7 +496,7 @@ read.text.default <- function(path, ...) { paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "text", paths) + sdf <- handledCallJMethod(read, "text", paths) dataFrame(sdf) } @@ -914,12 +914,13 @@ read.jdbc <- function(url, tableName, } else { numPartitions <- numToInt(numPartitions) } - sdf <- callJMethod(read, "jdbc", url, tableName, as.character(partitionColumn), - numToInt(lowerBound), numToInt(upperBound), numPartitions, jprops) + sdf <- handledCallJMethod(read, "jdbc", url, tableName, as.character(partitionColumn), + numToInt(lowerBound), numToInt(upperBound), numPartitions, jprops) } else if (length(predicates) > 0) { - sdf <- callJMethod(read, "jdbc", url, tableName, as.list(as.character(predicates)), jprops) + sdf <- handledCallJMethod(read, "jdbc", url, tableName, as.list(as.character(predicates)), + jprops) } else { - sdf <- callJMethod(read, "jdbc", url, tableName, jprops) + sdf <- handledCallJMethod(read, "jdbc", url, tableName, jprops) } dataFrame(sdf) } diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R index 03e70bb2cb82..0a789e6c379d 100644 --- a/R/pkg/R/backend.R +++ b/R/pkg/R/backend.R @@ -108,13 +108,27 @@ invokeJava <- function(isStatic, objId, methodName, ...) { conn <- get(".sparkRCon", .sparkREnv) writeBin(requestMessage, conn) - # TODO: check the status code to output error information returnStatus <- readInt(conn) + handleErrors(returnStatus, conn) + + # Backend will send +1 as keep alive value to prevent various connection timeouts + # on very long running jobs. See spark.r.heartBeatInterval + while (returnStatus == 1) { + returnStatus <- readInt(conn) + handleErrors(returnStatus, conn) + } + + readObject(conn) +} + +# Helper function to check for returned errors and print appropriate error message to user +handleErrors <- function(returnStatus, conn) { if (length(returnStatus) == 0) { stop("No status is returned. Java SparkR backend might have failed.") } - if (returnStatus != 0) { + + # 0 is success and +1 is reserved for heartbeats. Other negative values indicate errors. 
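The reader and writer calls above now go through handledCallJMethod, and handleErrors turns a negative status from the backend into a plain R error. A minimal sketch of the user-visible effect, assuming an active SparkR session and a path that does not exist (the path is purely illustrative):

  library(SparkR)
  sparkR.session()
  result <- tryCatch(
    read.json("/no/such/path.json"),
    error = function(e) {
      # The JVM exception is condensed into a short message such as
      # "Error in json : analysis error - Path does not exist"
      message(conditionMessage(e))
      NULL
    }
  )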
+ if (returnStatus < 0) { stop(readString(conn)) } - readObject(conn) } diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 2d341d836c13..9d82814211bc 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -19,7 +19,7 @@ # Creates a SparkR client connection object # if one doesn't already exist -connectBackend <- function(hostname, port, timeout = 6000) { +connectBackend <- function(hostname, port, timeout) { if (exists(".sparkRcon", envir = .sparkREnv)) { if (isOpen(.sparkREnv[[".sparkRCon"]])) { cat("SparkRBackend client connection already exists\n") diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 4d94b4cd05d4..f8a9d3ce5d91 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1485,7 +1485,7 @@ setMethod("soundex", #' Return the partition ID as a column #' -#' Return the partition ID of the Spark task as a SparkDataFrame column. +#' Return the partition ID as a SparkDataFrame column. #' Note that this is nondeterministic because it depends on data partitioning and #' task scheduling. #' @@ -2317,7 +2317,8 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' from_utc_timestamp #' -#' Assumes given timestamp is UTC and converts to given timezone. +#' Given a timestamp, which corresponds to a certain time of day in UTC, returns another timestamp +#' that corresponds to the same time of day in the given timezone. #' #' @param y Column to compute on. #' @param x time zone to use. @@ -2340,7 +2341,7 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), #' Locate the position of the first occurrence of substr column in the given string. #' Returns null if either of the arguments are null. #' -#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr +#' NOTE: The position is not zero based, but 1 based index. Returns 0 if substr #' could not be found in str. #' #' @param y column to check @@ -2391,7 +2392,8 @@ setMethod("next_day", signature(y = "Column", x = "character"), #' to_utc_timestamp #' -#' Assumes given timestamp is in given timezone and converts to UTC. +#' Given a timestamp, which corresponds to a certain time of day in the given timezone, returns +#' another timestamp that corresponds to the same time of day in UTC. #' #' @param y Column to compute on #' @param x timezone to use @@ -2539,7 +2541,7 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), #' shiftRight #' -#' Shift the given value numBits right. If the given value is a long value, it will return +#' (Signed) shift the given value numBits right. If the given value is a long value, it will return #' a long value else it will return an integer value. #' #' @param y column to compute on. @@ -2777,7 +2779,7 @@ setMethod("window", signature(x = "Column"), #' locate #' #' Locate the position of the first occurrence of substr. -#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr +#' NOTE: The position is not zero based, but 1 based index. Returns 0 if substr #' could not be found in str. #' #' @param substr a character string to be matched. @@ -2823,7 +2825,8 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' rand #' -#' Generate a random column with i.i.d. samples from U[0.0, 1.0]. +#' Generate a random column with independent and identically distributed (i.i.d.) samples +#' from U[0.0, 1.0]. #' #' @param seed a random seed. Can be missing. 
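As a quick illustration of the i.i.d. uniform samples described above, a sketch assuming an active SparkR session (the input data frame is just a placeholder):

  df <- createDataFrame(data.frame(id = 1:3))
  # One draw from U[0.0, 1.0] per row; fixing the seed makes the column reproducible
  head(select(df, df$id, alias(rand(42), "u")))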
#' @family normal_funcs @@ -2852,7 +2855,8 @@ setMethod("rand", signature(seed = "numeric"), #' randn #' -#' Generate a column with i.i.d. samples from the standard normal distribution. +#' Generate a column with independent and identically distributed (i.i.d.) samples from +#' the standard normal distribution. #' #' @param seed a random seed. Can be missing. #' @family normal_funcs @@ -3442,8 +3446,8 @@ setMethod("size", #' sort_array #' -#' Sorts the input array for the given column in ascending order, -#' according to the natural ordering of the array elements. +#' Sorts the input array in ascending or descending order according +#' to the natural ordering of the array elements. #' #' @param x A Column to sort #' @param asc A logical flag indicating the sorting order. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 107e1c638be7..7653ca7bccec 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1310,9 +1310,11 @@ setGeneric("window", function(x, ...) { standardGeneric("window") }) #' @export setGeneric("year", function(x) { standardGeneric("year") }) -#' @rdname spark.glm +###################### Spark.ML Methods ########################## + +#' @rdname fitted #' @export -setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) +setGeneric("fitted") #' @param x,y For \code{glm}: logical values indicating whether the response vector #' and model matrix used in the fitting process should be returned as @@ -1332,13 +1334,42 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") }) #' @export setGeneric("rbind", signature = "...") +#' @rdname spark.als +#' @export +setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) + +#' @rdname spark.gaussianMixture +#' @export +setGeneric("spark.gaussianMixture", + function(data, formula, ...) { standardGeneric("spark.gaussianMixture") }) + +#' @rdname spark.gbt +#' @export +setGeneric("spark.gbt", function(data, formula, ...) { standardGeneric("spark.gbt") }) + +#' @rdname spark.glm +#' @export +setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) + +#' @rdname spark.isoreg +#' @export +setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") }) + #' @rdname spark.kmeans #' @export setGeneric("spark.kmeans", function(data, formula, ...) { standardGeneric("spark.kmeans") }) -#' @rdname fitted +#' @rdname spark.kstest #' @export -setGeneric("fitted") +setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") }) + +#' @rdname spark.lda +#' @export +setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") }) + +#' @rdname spark.logit +#' @export +setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.logit") }) #' @rdname spark.mlp #' @export @@ -1348,13 +1379,14 @@ setGeneric("spark.mlp", function(data, ...) { standardGeneric("spark.mlp") }) #' @export setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") }) -#' @rdname spark.survreg +#' @rdname spark.randomForest #' @export -setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") }) +setGeneric("spark.randomForest", + function(data, formula, ...) { standardGeneric("spark.randomForest") }) -#' @rdname spark.lda +#' @rdname spark.survreg #' @export -setGeneric("spark.lda", function(data, ...) 
{ standardGeneric("spark.lda") }) +setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") }) #' @rdname spark.lda #' @export @@ -1364,20 +1396,6 @@ setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark #' @export setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") }) -#' @rdname spark.isoreg -#' @export -setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") }) - -#' @rdname spark.gaussianMixture -#' @export -setGeneric("spark.gaussianMixture", - function(data, formula, ...) { - standardGeneric("spark.gaussianMixture") - }) - -#' @rdname spark.logit -#' @export -setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.logit") }) #' @param object a fitted ML model object. #' @param path the directory where the model is saved. @@ -1385,11 +1403,3 @@ setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark. #' @rdname write.ml #' @export setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") }) - -#' @rdname spark.als -#' @export -setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) - -#' @rdname spark.kstest -#' @export -setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") }) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 629f284b79f3..1065b4b37d7f 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -102,6 +102,34 @@ setClass("KSTest", representation(jobj = "jobj")) #' @note LogisticRegressionModel since 2.1.0 setClass("LogisticRegressionModel", representation(jobj = "jobj")) +#' S4 class that represents a RandomForestRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala RandomForestRegressionModel +#' @export +#' @note RandomForestRegressionModel since 2.1.0 +setClass("RandomForestRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a RandomForestClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala RandomForestClassificationModel +#' @export +#' @note RandomForestClassificationModel since 2.1.0 +setClass("RandomForestClassificationModel", representation(jobj = "jobj")) + +#' S4 class that represents a GBTRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala GBTRegressionModel +#' @export +#' @note GBTRegressionModel since 2.1.0 +setClass("GBTRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a GBTClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala GBTClassificationModel +#' @export +#' @note GBTClassificationModel since 2.1.0 +setClass("GBTClassificationModel", representation(jobj = "jobj")) + #' Saves the MLlib model to the input path #' #' Saves the MLlib model to the input path. 
For more information, see the specific @@ -110,9 +138,10 @@ setClass("LogisticRegressionModel", representation(jobj = "jobj")) #' @name write.ml #' @export #' @seealso \link{spark.glm}, \link{glm}, -#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, +#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.gbt}, \link{spark.isoreg}, +#' @seealso \link{spark.kmeans}, #' @seealso \link{spark.lda}, \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, -#' @seealso \link{spark.survreg} +#' @seealso \link{spark.randomForest}, \link{spark.survreg}, #' @seealso \link{read.ml} NULL @@ -124,8 +153,10 @@ NULL #' @name predict #' @export #' @seealso \link{spark.glm}, \link{glm}, -#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, -#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg} +#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.gbt}, \link{spark.isoreg}, +#' @seealso \link{spark.kmeans}, +#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, +#' @seealso \link{spark.randomForest}, \link{spark.survreg} NULL write_internal <- function(object, path, overwrite = FALSE) { @@ -619,7 +650,7 @@ setMethod("fitted", signature(object = "KMeansModel"), # Get the summary of a k-means model #' @param object a fitted k-means model. -#' @return \code{summary} returns the model's coefficients, size and cluster. +#' @return \code{summary} returns the model's features, coefficients, k, size and cluster. #' @rdname spark.kmeans #' @export #' @note summary(KMeansModel) since 2.0.0 @@ -664,15 +695,15 @@ setMethod("predict", signature(object = "KMeansModel"), #' @param data SparkDataFrame for training #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param regParam the regularization parameter. Default is 0.0. +#' @param regParam the regularization parameter. #' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 penalty. #' For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, the penalty is a combination #' of L1 and L2. Default is 0.0 which is an L2 penalty. #' @param maxIter maximum iteration number. #' @param tol convergence tolerance of iterations. -#' @param fitIntercept whether to fit an intercept term. Default is TRUE. +#' @param fitIntercept whether to fit an intercept term. #' @param family the name of family which is a description of the label distribution to be used in the model. -#' Supported options: Default is "auto". +#' Supported options: #' \itemize{ #' \item{"auto": Automatically select the family based on the number of classes: #' If number of classes == 1 || number of classes == 2, set to "binomial". @@ -690,11 +721,11 @@ setMethod("predict", signature(object = "KMeansModel"), #' threshold p is equivalent to setting thresholds c(1-p, p). In multiclass (or binary) classification to adjust the probability of #' predicting each class. Array must have length equal to the number of classes, with values > 0, #' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p -#' is the original probability of that class and t is the class's threshold. Default is 0.5. +#' is the original probability of that class and t is the class's threshold. #' @param weightCol The weight column name. 
#' @param aggregationDepth depth for treeAggregate (>= 2). If the dimensions of features or the number of partitions -#' are large, this param could be adjusted to a larger size. Default is 2. -#' @param probabilityCol column name for predicted class conditional probabilities. Default is "probability". +#' are large, this param could be adjusted to a larger size. +#' @param probabilityCol column name for predicted class conditional probabilities. #' @param ... additional arguments passed to the method. #' @return \code{spark.logit} returns a fitted logistic regression model #' @rdname spark.logit @@ -776,8 +807,10 @@ setMethod("predict", signature(object = "LogisticRegressionModel"), # Get the summary of an LogisticRegressionModel #' @param object an LogisticRegressionModel fitted by \code{spark.logit} -#' @return \code{summary} returns the Binary Logistic regression results of a given model as lists. Note that -#' Multinomial logistic regression summary is not available now. +#' @return \code{summary} returns the Binary Logistic regression results of a given model as list, +#' including roc, areaUnderROC, pr, fMeasureByThreshold, precisionByThreshold, +#' recallByThreshold, totalIterations, objectiveHistory. Note that Multinomial logistic +#' regression summary is not available now. #' @rdname spark.logit #' @aliases summary,LogisticRegressionModel-method #' @export @@ -1122,6 +1155,14 @@ read.ml <- function(path) { new("ALSModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LogisticRegressionWrapper")) { new("LogisticRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestRegressorWrapper")) { + new("RandomForestRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) { + new("RandomForestClassificationModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTRegressorWrapper")) { + new("GBTRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTClassifierWrapper")) { + new("GBTClassificationModel", jobj = jobj) } else { stop("Unsupported model: ", jobj) } @@ -1177,13 +1218,13 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula #' data and \code{write.ml}/\code{read.ml} to save/load fitted models. #' #' @param data A SparkDataFrame for training -#' @param features Features column name, default "features". Either libSVM-format column or -#' character-format column is valid. -#' @param k Number of topics, default 10 -#' @param maxIter Maximum iterations, default 20 -#' @param optimizer Optimizer to train an LDA model, "online" or "em", default "online" +#' @param features Features column name. Either libSVM-format column or character-format column is +#' valid. +#' @param k Number of topics. +#' @param maxIter Maximum iterations. +#' @param optimizer Optimizer to train an LDA model, "online" or "em", default is "online". #' @param subsamplingRate (For online optimizer) Fraction of the corpus to be sampled and used in -#' each iteration of mini-batch gradient descent, in range (0, 1], default 0.05 +#' each iteration of mini-batch gradient descent, in range (0, 1]. #' @param topicConcentration concentration parameter (commonly named \code{beta} or \code{eta}) for #' the prior placed on topic distributions over terms, default -1 to set automatically on the #' Spark side. Use \code{summary} to retrieve the effective topicConcentration. 
Only 1-size @@ -1244,7 +1285,7 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"), # similarly to R's summary(). #' @param object a fitted AFT survival regression model. -#' @return \code{summary} returns a list containing the model's coefficients, +#' @return \code{summary} returns a list containing the model's features, coefficients, #' intercept and log(scale) #' @rdname spark.survreg #' @export @@ -1332,7 +1373,7 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = # Get the summary of a multivariate gaussian mixture model #' @param object a fitted gaussian mixture model. -#' @return \code{summary} returns the model's lambda, mu, sigma and posterior. +#' @return \code{summary} returns the model's lambda, mu, sigma, k, dim and posterior. #' @aliases spark.gaussianMixture,SparkDataFrame,formula-method #' @rdname spark.gaussianMixture #' @export @@ -1617,3 +1658,451 @@ print.summary.KSTest <- function(x, ...) { cat(summaryStr, "\n") invisible(x) } + +#' Random Forest Model for Regression and Classification +#' +#' \code{spark.randomForest} fits a Random Forest Regression model or Classification model on +#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Random Forest +#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to +#' save/load fitted models. +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-regression}{ +#' Random Forest Regression} and +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-classifier}{ +#' Random Forest Classification} +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' @param type type of model, one of "regression" or "classification", to fit +#' @param maxDepth Maximum depth of the tree (>= 0). +#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing +#' how to split on features at each node. More bins give higher granularity. Must be +#' >= 2 and >= number of categories in any categorical feature. +#' @param numTrees Number of trees to train (>= 1). +#' @param impurity Criterion used for information gain calculation. +#' For regression, must be "variance". For classification, must be one of +#' "entropy" and "gini", default is "gini". +#' @param featureSubsetStrategy The number of features to consider for splits at each tree node. +#' Supported options: "auto", "all", "onethird", "sqrt", "log2", (0.0-1.0], [1-n]. +#' @param seed integer seed for random number generation. +#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in +#' range (0, 1]. +#' @param minInstancesPerNode Minimum number of instances each child must have after split. +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. +#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. +#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with +#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching +#' can speed up training of deeper trees. 
Users can set how often should the +#' cache be checkpointed or disable it by setting checkpointInterval. +#' @param probabilityCol column name for predicted class conditional probabilities, only for +#' classification. +#' @param ... additional arguments passed to the method. +#' @aliases spark.randomForest,SparkDataFrame,formula-method +#' @return \code{spark.randomForest} returns a fitted Random Forest model. +#' @rdname spark.randomForest +#' @name spark.randomForest +#' @export +#' @examples +#' \dontrun{ +#' # fit a Random Forest Regression Model +#' df <- createDataFrame(longley) +#' model <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # fit a Random Forest Classification Model +#' df <- createDataFrame(iris) +#' model <- spark.randomForest(df, Species ~ Petal_Length + Petal_Width, "classification") +#' } +#' @note spark.randomForest since 2.1.0 +setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, type = c("regression", "classification"), + maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL, + featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0, + minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, + maxMemoryInMB = 256, cacheNodeIds = FALSE, probabilityCol = "probability") { + type <- match.arg(type) + formula <- paste(deparse(formula), collapse = "") + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } + switch(type, + regression = { + if (is.null(impurity)) impurity <- "variance" + impurity <- match.arg(impurity, "variance") + jobj <- callJStatic("org.apache.spark.ml.r.RandomForestRegressorWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(numTrees), + impurity, as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + as.character(featureSubsetStrategy), seed, + as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("RandomForestRegressionModel", jobj = jobj) + }, + classification = { + if (is.null(impurity)) impurity <- "gini" + impurity <- match.arg(impurity, c("gini", "entropy")) + jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(numTrees), + impurity, as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + as.character(featureSubsetStrategy), seed, + as.numeric(subsamplingRate), as.character(probabilityCol), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("RandomForestClassificationModel", jobj = jobj) + } + ) + }) + +# Makes predictions from a Random Forest Regression model or Classification model + +#' @param newData a SparkDataFrame for testing. 
+#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named +#' "prediction" +#' @rdname spark.randomForest +#' @aliases predict,RandomForestRegressionModel-method +#' @export +#' @note predict(RandomForestRegressionModel) since 2.1.0 +setMethod("predict", signature(object = "RandomForestRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' @rdname spark.randomForest +#' @aliases predict,RandomForestClassificationModel-method +#' @export +#' @note predict(RandomForestClassificationModel) since 2.1.0 +setMethod("predict", signature(object = "RandomForestClassificationModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Save the Random Forest Regression or Classification model to the input path. + +#' @param object A fitted Random Forest regression model or classification model +#' @param path The directory where the model is saved +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @aliases write.ml,RandomForestRegressionModel,character-method +#' @rdname spark.randomForest +#' @export +#' @note write.ml(RandomForestRegressionModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' @aliases write.ml,RandomForestClassificationModel,character-method +#' @rdname spark.randomForest +#' @export +#' @note write.ml(RandomForestClassificationModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "RandomForestClassificationModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +# Create the summary of a tree ensemble model (eg. 
Random Forest, GBT) +summary.treeEnsemble <- function(model) { + jobj <- model@jobj + formula <- callJMethod(jobj, "formula") + numFeatures <- callJMethod(jobj, "numFeatures") + features <- callJMethod(jobj, "features") + featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString") + numTrees <- callJMethod(jobj, "numTrees") + treeWeights <- callJMethod(jobj, "treeWeights") + list(formula = formula, + numFeatures = numFeatures, + features = features, + featureImportances = featureImportances, + numTrees = numTrees, + treeWeights = treeWeights, + jobj = jobj) +} + +# Get the summary of a Random Forest Regression Model + +#' @return \code{summary} returns a summary object of the fitted model, a list of components +#' including formula, number of features, list of features, feature importances, number of +#' trees, and tree weights +#' @rdname spark.randomForest +#' @aliases summary,RandomForestRegressionModel-method +#' @export +#' @note summary(RandomForestRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "RandomForestRegressionModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.RandomForestRegressionModel" + ans + }) + +# Get the summary of a Random Forest Classification Model + +#' @rdname spark.randomForest +#' @aliases summary,RandomForestClassificationModel-method +#' @export +#' @note summary(RandomForestClassificationModel) since 2.1.0 +setMethod("summary", signature(object = "RandomForestClassificationModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.RandomForestClassificationModel" + ans + }) + +# Prints the summary of tree ensemble models (eg. Random Forest, GBT) +print.summary.treeEnsemble <- function(x) { + jobj <- x$jobj + cat("Formula: ", x$formula) + cat("\nNumber of features: ", x$numFeatures) + cat("\nFeatures: ", unlist(x$features)) + cat("\nFeature importances: ", x$featureImportances) + cat("\nNumber of trees: ", x$numTrees) + cat("\nTree weights: ", unlist(x$treeWeights)) + + summaryStr <- callJMethod(jobj, "summary") + cat("\n", summaryStr, "\n") + invisible(x) +} + +# Prints the summary of Random Forest Regression Model + +#' @param x summary object of Random Forest regression model or classification model +#' returned by \code{summary}. +#' @rdname spark.randomForest +#' @export +#' @note print.summary.RandomForestRegressionModel since 2.1.0 +print.summary.RandomForestRegressionModel <- function(x, ...) { + print.summary.treeEnsemble(x) +} + +# Prints the summary of Random Forest Classification Model + +#' @rdname spark.randomForest +#' @export +#' @note print.summary.RandomForestClassificationModel since 2.1.0 +print.summary.RandomForestClassificationModel <- function(x, ...) { + print.summary.treeEnsemble(x) +} + +#' Gradient Boosted Tree Model for Regression and Classification +#' +#' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a +#' SparkDataFrame. Users can call \code{summary} to get a summary of the fitted +#' Gradient Boosted Tree model, \code{predict} to make predictions on new data, and +#' \code{write.ml}/\code{read.ml} to save/load fitted models. 
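The fields assembled by summary.treeEnsemble above can be inspected directly on a fitted model; a brief sketch assuming an active SparkR session, with a deliberately small forest:

  df <- suppressWarnings(createDataFrame(longley))
  rfModel <- spark.randomForest(df, Employed ~ ., type = "regression", numTrees = 5, seed = 1)
  s <- summary(rfModel)
  s$numTrees              # 5
  s$featureImportances    # string rendering of the per-feature importance vector
  unlist(s$treeWeights)   # one weight per tree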
+#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-regression}{ +#' GBT Regression} and +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-classifier}{ +#' GBT Classification} +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' @param type type of model, one of "regression" or "classification", to fit +#' @param maxDepth Maximum depth of the tree (>= 0). +#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing +#' how to split on features at each node. More bins give higher granularity. Must be +#' >= 2 and >= number of categories in any categorical feature. +#' @param maxIter Param for maximum number of iterations (>= 0). +#' @param stepSize Param for Step size to be used for each iteration of optimization. +#' @param lossType Loss function which GBT tries to minimize. +#' For classification, must be "logistic". For regression, must be one of +#' "squared" (L2) and "absolute" (L1), default is "squared". +#' @param seed integer seed for random number generation. +#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in +#' range (0, 1]. +#' @param minInstancesPerNode Minimum number of instances each child must have after split. If a +#' split causes the left or right child to have fewer than +#' minInstancesPerNode, the split will be discarded as invalid. Should be +#' >= 1. +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. +#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. +#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with +#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching +#' can speed up training of deeper trees. Users can set how often should the +#' cache be checkpointed or disable it by setting checkpointInterval. +#' @param ... additional arguments passed to the method. +#' @aliases spark.gbt,SparkDataFrame,formula-method +#' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model. +#' @rdname spark.gbt +#' @name spark.gbt +#' @export +#' @examples +#' \dontrun{ +#' # fit a Gradient Boosted Tree Regression Model +#' df <- createDataFrame(longley) +#' model <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # fit a Gradient Boosted Tree Classification Model +#' # label must be binary - Only binary classification is supported for GBT. 
+#' df <- createDataFrame(iris[iris$Species != "virginica", ]) +#' model <- spark.gbt(df, Species ~ Petal_Length + Petal_Width, "classification") +#' +#' # numeric label is also supported +#' iris2 <- iris[iris$Species != "virginica", ] +#' iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) +#' df <- createDataFrame(iris2) +#' model <- spark.gbt(df, NumericSpecies ~ ., type = "classification") +#' } +#' @note spark.gbt since 2.1.0 +setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, type = c("regression", "classification"), + maxDepth = 5, maxBins = 32, maxIter = 20, stepSize = 0.1, lossType = NULL, + seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0, + checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE) { + type <- match.arg(type) + formula <- paste(deparse(formula), collapse = "") + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } + switch(type, + regression = { + if (is.null(lossType)) lossType <- "squared" + lossType <- match.arg(lossType, c("squared", "absolute")) + jobj <- callJStatic("org.apache.spark.ml.r.GBTRegressorWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(maxIter), + as.numeric(stepSize), as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + lossType, seed, as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("GBTRegressionModel", jobj = jobj) + }, + classification = { + if (is.null(lossType)) lossType <- "logistic" + lossType <- match.arg(lossType, "logistic") + jobj <- callJStatic("org.apache.spark.ml.r.GBTClassifierWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(maxIter), + as.numeric(stepSize), as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + lossType, seed, as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("GBTClassificationModel", jobj = jobj) + } + ) + }) + +# Makes predictions from a Gradient Boosted Tree Regression model or Classification model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named +#' "prediction" +#' @rdname spark.gbt +#' @aliases predict,GBTRegressionModel-method +#' @export +#' @note predict(GBTRegressionModel) since 2.1.0 +setMethod("predict", signature(object = "GBTRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' @rdname spark.gbt +#' @aliases predict,GBTClassificationModel-method +#' @export +#' @note predict(GBTClassificationModel) since 2.1.0 +setMethod("predict", signature(object = "GBTClassificationModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Save the Gradient Boosted Tree Regression or Classification model to the input path. + +#' @param object A fitted Gradient Boosted Tree regression model or classification model +#' @param path The directory where the model is saved +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. 
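A short sketch of the overwrite behaviour described above, using a small GBT regression fit (assumes an active SparkR session; maxIter is kept low only to keep the example fast):

  df <- suppressWarnings(createDataFrame(longley))
  gbtModel <- spark.gbt(df, Employed ~ ., type = "regression", maxIter = 5)
  modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp")
  write.ml(gbtModel, modelPath)                    # first save succeeds
  # write.ml(gbtModel, modelPath)                  # would fail: the path already exists
  write.ml(gbtModel, modelPath, overwrite = TRUE)  # explicit overwrite is required
  unlink(modelPath)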
+#' @aliases write.ml,GBTRegressionModel,character-method +#' @rdname spark.gbt +#' @export +#' @note write.ml(GBTRegressionModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' @aliases write.ml,GBTClassificationModel,character-method +#' @rdname spark.gbt +#' @export +#' @note write.ml(GBTClassificationModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "GBTClassificationModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +# Get the summary of a Gradient Boosted Tree Regression Model + +#' @return \code{summary} returns a summary object of the fitted model, a list of components +#' including formula, number of features, list of features, feature importances, number of +#' trees, and tree weights +#' @rdname spark.gbt +#' @aliases summary,GBTRegressionModel-method +#' @export +#' @note summary(GBTRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "GBTRegressionModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.GBTRegressionModel" + ans + }) + +# Get the summary of a Gradient Boosted Tree Classification Model + +#' @rdname spark.gbt +#' @aliases summary,GBTClassificationModel-method +#' @export +#' @note summary(GBTClassificationModel) since 2.1.0 +setMethod("summary", signature(object = "GBTClassificationModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.GBTClassificationModel" + ans + }) + +# Prints the summary of Gradient Boosted Tree Regression Model + +#' @param x summary object of Gradient Boosted Tree regression model or classification model +#' returned by \code{summary}. +#' @rdname spark.gbt +#' @export +#' @note print.summary.GBTRegressionModel since 2.1.0 +print.summary.GBTRegressionModel <- function(x, ...) { + print.summary.treeEnsemble(x) +} + +# Prints the summary of Gradient Boosted Tree Classification Model + +#' @rdname spark.gbt +#' @export +#' @note print.summary.GBTClassificationModel since 2.1.0 +print.summary.GBTClassificationModel <- function(x, ...) 
{ + print.summary.treeEnsemble(x) +} diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index cc6d591bb2f4..6b4a2f2fdc85 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -154,6 +154,7 @@ sparkR.sparkContext <- function( packages <- processSparkPackages(sparkPackages) existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "") + connectionTimeout <- as.numeric(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000")) if (existingPort != "") { if (length(packages) != 0) { warning(paste("sparkPackages has no effect when using spark-submit or sparkR shell", @@ -187,6 +188,7 @@ sparkR.sparkContext <- function( backendPort <- readInt(f) monitorPort <- readInt(f) rLibPath <- readString(f) + connectionTimeout <- readInt(f) close(f) file.remove(path) if (length(backendPort) == 0 || backendPort == 0 || @@ -194,7 +196,9 @@ sparkR.sparkContext <- function( length(rLibPath) != 1) { stop("JVM failed to launch") } - assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv) + assign(".monitorConn", + socketConnection(port = monitorPort, timeout = connectionTimeout), + envir = .sparkREnv) assign(".backendLaunched", 1, envir = .sparkREnv) if (rLibPath != "") { assign(".libPath", rLibPath, envir = .sparkREnv) @@ -204,7 +208,7 @@ sparkR.sparkContext <- function( .sparkREnv$backendPort <- backendPort tryCatch({ - connectBackend("localhost", backendPort) + connectBackend("localhost", backendPort, timeout = connectionTimeout) }, error = function(err) { stop("Failed to connect JVM\n") diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index c4e78cbb804d..20004549cc03 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -338,21 +338,41 @@ varargsToEnv <- function(...) { # into string. varargsToStrEnv <- function(...) { pairs <- list(...) + nameList <- names(pairs) env <- new.env() - for (name in names(pairs)) { - value <- pairs[[name]] - if (!(is.logical(value) || is.numeric(value) || is.character(value) || is.null(value))) { - stop(paste0("Unsupported type for ", name, " : ", class(value), - ". Supported types are logical, numeric, character and NULL.")) - } - if (is.logical(value)) { - env[[name]] <- tolower(as.character(value)) - } else if (is.null(value)) { - env[[name]] <- value - } else { - env[[name]] <- as.character(value) + ignoredNames <- list() + + if (is.null(nameList)) { + # When all arguments are not named, names(..) returns NULL. + ignoredNames <- pairs + } else { + for (i in seq_along(pairs)) { + name <- nameList[i] + value <- pairs[i] + if (identical(name, "")) { + # When some of arguments are not named, name is "". + ignoredNames <- append(ignoredNames, value) + } else { + value <- pairs[[name]] + if (!(is.logical(value) || is.numeric(value) || is.character(value) || is.null(value))) { + stop(paste0("Unsupported type for ", name, " : ", class(value), + ". Supported types are logical, numeric, character and NULL."), call. = FALSE) + } + if (is.logical(value)) { + env[[name]] <- tolower(as.character(value)) + } else if (is.null(value)) { + env[[name]] <- value + } else { + env[[name]] <- as.character(value) + } + } } } + + if (length(ignoredNames) != 0) { + warning(paste0("Unnamed arguments ignored: ", paste(ignoredNames, collapse = ", "), "."), + call. 
= FALSE) + } env } diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 6d1fccc7c058..33e9d0d267ac 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -64,6 +64,16 @@ test_that("spark.glm and predict", { rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + # binomial family + binomialTraining <- training[training$Species %in% c("versicolor", "virginica"), ] + model <- spark.glm(binomialTraining, Species ~ Sepal_Length + Sepal_Width, + family = binomial(link = "logit")) + prediction <- predict(model, binomialTraining) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "character") + expected <- c("virginica", "virginica", "virginica", "versicolor", "virginica", + "versicolor", "virginica", "versicolor", "virginica", "versicolor") + expect_equal(as.list(take(select(prediction, "prediction"), 10))[[1]], expected) + # poisson family model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species, family = poisson(link = identity)) @@ -128,12 +138,12 @@ test_that("spark.glm summary", { expect_equal(stats$aic, rStats$aic) # Test spark.glm works with weighted dataset - a1 <- c(0, 1, 2, 3) - a2 <- c(5, 2, 1, 3) - w <- c(1, 2, 3, 4) - b <- c(1, 0, 1, 0) + a1 <- c(0, 1, 2, 3, 4) + a2 <- c(5, 2, 1, 3, 2) + w <- c(1, 2, 3, 4, 5) + b <- c(1, 0, 1, 0, 0) data <- as.data.frame(cbind(a1, a2, w, b)) - df <- suppressWarnings(createDataFrame(data)) + df <- createDataFrame(data) stats <- summary(spark.glm(df, b ~ a1 + a2, family = "binomial", weightCol = "w")) rStats <- summary(glm(b ~ a1 + a2, family = "binomial", data = data, weights = w)) @@ -158,7 +168,7 @@ test_that("spark.glm summary", { data <- as.data.frame(cbind(a1, a2, b)) df <- suppressWarnings(createDataFrame(data)) regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0)) - expect_equal(regStats$aic, 13.32836, tolerance = 1e-4) # 13.32836 is from summary() result + expect_equal(regStats$aic, 14.00976, tolerance = 1e-4) # 14.00976 is from summary() result }) test_that("spark.glm save/load", { @@ -575,7 +585,7 @@ test_that("spark.isotonicRegression", { feature <- c(0.0, 1.0, 2.0, 3.0, 4.0) weight <- c(1.0, 1.0, 1.0, 1.0, 1.0) data <- as.data.frame(cbind(label, feature, weight)) - df <- suppressWarnings(createDataFrame(data)) + df <- createDataFrame(data) model <- spark.isoreg(df, label ~ feature, isotonic = FALSE, weightCol = "weight") @@ -871,4 +881,140 @@ test_that("spark.kstest", { expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") }) +test_that("spark.randomForest Regression", { + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, + numTrees = 1) + + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + + stats <- summary(model) + expect_equal(stats$numTrees, 1) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + + model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, + numTrees = 20, seed = 123) + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.379, 61.096, 60.636, 62.258, + 63.736, 64.296, 64.868, 64.300, + 66.709, 
67.697, 67.966, 67.252, + 68.866, 69.593, 69.195, 69.658), + tolerance = 1e-4) + stats <- summary(model) + expect_equal(stats$numTrees, 20) + + modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) +}) + +test_that("spark.randomForest Classification", { + data <- suppressWarnings(createDataFrame(iris)) + model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + + modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) +}) + +test_that("spark.gbt", { + # regression + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123) + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + stats <- summary(model) + expect_equal(stats$numTrees, 20) + expect_equal(stats$formula, "Employed ~ .") + expect_equal(stats$numFeatures, 6) + expect_equal(length(stats$treeWeights), 20) + + modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) + + # classification + # label must be binary - GBTClassifier currently only supports binary classification. 
+ iris2 <- iris[iris$Species != "virginica", ] + data <- suppressWarnings(createDataFrame(iris2)) + model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification") + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + predictions <- collect(predict(model, data))$prediction + # test string prediction values + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) + + modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) + + iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) + df <- suppressWarnings(createDataFrame(iris2)) + m <- spark.gbt(df, NumericSpecies ~ ., type = "classification") + s <- summary(m) + # test numeric prediction values + expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction)) + expect_equal(s$numFeatures, 5) + expect_equal(s$numTrees, 20) +}) + sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 9289db57b6d6..ee48baa59c7a 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1222,16 +1222,16 @@ test_that("column functions", { # Test struct() df <- createDataFrame(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), schema = c("a", "b", "c")) - result <- collect(select(df, struct("a", "c"))) + result <- collect(select(df, alias(struct("a", "c"), "d"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)), - listToStruct(list(a = 4L, c = 6L))) + expected$"d" <- list(listToStruct(list(a = 1L, c = 3L)), + listToStruct(list(a = 4L, c = 6L))) expect_equal(result, expected) - result <- collect(select(df, struct(df$a, df$b))) + result <- collect(select(df, alias(struct(df$a, df$b), "d"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)), - listToStruct(list(a = 4L, b = 5L))) + expected$"d" <- list(listToStruct(list(a = 1L, b = 2L)), + listToStruct(list(a = 4L, b = 5L))) expect_equal(result, expected) # Test encode(), decode() @@ -2659,7 +2659,15 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume # It makes sure that we can omit path argument in write.df API and then it calls # DataFrameWriter.save() without path. expect_error(write.df(df, source = "csv"), - "Error in save : illegal argument - 'path' is not specified") + "Error in save : illegal argument - Expected exactly one path to be specified") + expect_error(write.json(df, jsonPath), + "Error in json : analysis error - path file:.*already exists") + expect_error(write.text(df, jsonPath), + "Error in text : analysis error - path file:.*already exists") + expect_error(write.orc(df, jsonPath), + "Error in orc : analysis error - path file:.*already exists") + expect_error(write.parquet(df, jsonPath), + "Error in parquet : analysis error - path file:.*already exists") # Arguments checking in R side. 
expect_error(write.df(df, "data.tmp", source = c(1, 2)), @@ -2679,6 +2687,11 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume paste("Error in loadDF : analysis error - Unable to infer schema for JSON at .", "It must be specified manually")) expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist") + expect_error(read.json("arbitrary_path"), "Error in json : analysis error - Path does not exist") + expect_error(read.text("arbitrary_path"), "Error in text : analysis error - Path does not exist") + expect_error(read.orc("arbitrary_path"), "Error in orc : analysis error - Path does not exist") + expect_error(read.parquet("arbitrary_path"), + "Error in parquet : analysis error - Path does not exist") # Arguments checking in R side. expect_error(read.df(path = c(3)), @@ -2686,6 +2699,9 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume expect_error(read.df(jsonPath, source = c(1, 2)), paste("source should be character, NULL or omitted. It is the datasource specified", "in 'spark.sql.sources.default' configuration by default.")) + + expect_warning(read.json(jsonPath, a = 1, 2, 3, "a"), + "Unnamed arguments ignored: 2, 3, a.") }) unlink(parquetPath) diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index a20254e9b3fa..607c407f04f9 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -224,6 +224,8 @@ test_that("varargsToStrEnv", { expect_error(varargsToStrEnv(a = list(1, "a")), paste0("Unsupported type for a : list. Supported types are logical, ", "numeric, character and NULL.")) + expect_warning(varargsToStrEnv(a = 1, 2, 3, 4), "Unnamed arguments ignored: 2, 3, 4.") + expect_warning(varargsToStrEnv(1, 2, 3, 4), "Unnamed arguments ignored: 1, 2, 3, 4.") }) sparkR.session.stop() diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index b92e6be995ca..3a318b71ea06 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -18,6 +18,7 @@ # Worker daemon rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +connectionTimeout <- as.integer(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000")) dirs <- strsplit(rLibDir, ",")[[1]] script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R") @@ -26,7 +27,8 @@ script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R") suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) -inputCon <- socketConnection(port = port, open = "rb", blocking = TRUE, timeout = 3600) +inputCon <- socketConnection( + port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout) while (TRUE) { ready <- socketSelect(list(inputCon)) diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index cfe41ded200c..03e745014786 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -90,6 +90,7 @@ bootTime <- currentTimeSecs() bootElap <- elapsedSecs() rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +connectionTimeout <- as.integer(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000")) dirs <- strsplit(rLibDir, ",")[[1]] # Set libPaths to include SparkR package as loadNamespace needs this # TODO: Figure out if we can avoid this by not loading any objects that require @@ -98,8 +99,10 @@ dirs <- strsplit(rLibDir, ",")[[1]] suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) -inputCon <- socketConnection(port = port, blocking = TRUE, open = 
"rb") -outputCon <- socketConnection(port = port, blocking = TRUE, open = "wb") +inputCon <- socketConnection( + port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout) +outputCon <- socketConnection( + port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout) # read the index of the current partition inside the RDD partition <- SparkR:::readInt(inputCon) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html index 1fd6ef4a7125..42e2d9abdeb5 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -68,16 +68,16 @@ {{#applications}} - {{id}} + {{id}} {{name}} {{#attempts}} - {{attemptId}} + {{attemptId}} {{startTime}} {{endTime}} {{duration}} {{sparkUser}} {{lastUpdated}} - Download + Download {{/attempts}} {{/applications}} diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 2a32e18672a2..6c0ec8d5fce5 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -119,7 +119,11 @@ $(document).ready(function() { } } - var data = {"applications": array} + var data = { + "uiroot": uiRoot, + "applications": array + } + $.get("static/historypage-template.html", function(template) { historySummary.append(Mustache.render($(template).filter("#history-summary-template").html(),data)); var selector = "#history-summary-table"; diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js index 14b06bfe860e..0315ebf5c48a 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/table.js +++ b/core/src/main/resources/org/apache/spark/ui/static/table.js @@ -36,7 +36,7 @@ function toggleThreadStackTrace(threadId, forceAdd) { if (stackTrace.length == 0) { var stackTraceText = $('#' + threadId + "_td_stacktrace").html() var threadCell = $("#thread_" + threadId + "_tr") - threadCell.after("
" +
+        threadCell.after("
" +
             stackTraceText +  "
") } else { if (!forceAdd) { @@ -73,6 +73,7 @@ function onMouseOverAndOut(threadId) { $("#" + threadId + "_td_id").toggleClass("threaddump-td-mouseover"); $("#" + threadId + "_td_name").toggleClass("threaddump-td-mouseover"); $("#" + threadId + "_td_state").toggleClass("threaddump-td-mouseover"); + $("#" + threadId + "_td_locking").toggleClass("threaddump-td-mouseover"); } function onSearchStringChange() { diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index e37307aa1f70..0fa1fcf25f8b 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -15,6 +15,12 @@ * limitations under the License. */ +var uiRoot = ""; + +function setUIRoot(val) { + uiRoot = val; +} + function collapseTablePageLoad(name, table){ if (window.localStorage.getItem(name) == "true") { // Set it to false so that the click function can revert it diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4694790c72cd..25a3d609a6b0 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -183,6 +183,8 @@ class SparkContext(config: SparkConf) extends Logging { // log out Spark Version in Spark driver log logInfo(s"Running Spark version $SPARK_VERSION") + warnDeprecatedVersions() + /* ------------------------------------------------------------------------------------- * | Private variables. These variables keep the internal state of the context, and are | | not accessible by the outside world. They're mutable since we want to initialize all | @@ -346,6 +348,16 @@ class SparkContext(config: SparkConf) extends Logging { value } + private def warnDeprecatedVersions(): Unit = { + val javaVersion = System.getProperty("java.version").split("[+.\\-]+", 3) + if (javaVersion.length >= 2 && javaVersion(1).toInt == 7) { + logWarning("Support for Java 7 is deprecated as of Spark 2.0.0") + } + if (scala.util.Properties.releaseVersion.exists(_.startsWith("2.10"))) { + logWarning("Support for Scala 2.10 is deprecated as of Spark 2.1.0") + } + } + /** Control our logLevel. This overrides any user-defined log settings. * @param logLevel The desired log level as a string. * Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN @@ -1716,29 +1728,12 @@ class SparkContext(config: SparkConf) extends Logging { key = uri.getScheme match { // A JAR file which exists only on the driver node case null | "file" => - if (master == "yarn" && deployMode == "cluster") { - // In order for this to work in yarn cluster mode the user must specify the - // --addJars option to the client to upload the file into the distributed cache - // of the AM to make it show up in the current working directory. - val fileName = new Path(uri.getPath).getName() - try { - env.rpcEnv.fileServer.addJar(new File(fileName)) - } catch { - case e: Exception => - // For now just log an error but allow to go through so spark examples work. - // The spark examples don't really need the jar distributed since its also - // the app jar. 
- logError("Error adding jar (" + e + "), was the --addJars option used?") - null - } - } else { - try { - env.rpcEnv.fileServer.addJar(new File(uri.getPath)) - } catch { - case exc: FileNotFoundException => - logError(s"Jar not found at $path") - null - } + try { + env.rpcEnv.fileServer.addJar(new File(uri.getPath)) + } catch { + case exc: FileNotFoundException => + logError(s"Jar not found at $path") + null } // A JAR file which exists locally on every worker node case "local" => @@ -1762,8 +1757,26 @@ class SparkContext(config: SparkConf) extends Logging { */ def listJars(): Seq[String] = addedJars.keySet.toSeq - // Shut down the SparkContext. - def stop() { + /** + * Shut down the SparkContext. + */ + def stop(): Unit = { + if (env.rpcEnv.isInRPCThread) { + // `stop` will block until all RPC threads exit, so we cannot call stop inside a RPC thread. + // We should launch a new thread to call `stop` to avoid dead-lock. + new Thread("stop-spark-context") { + setDaemon(true) + + override def run(): Unit = { + _stop() + } + }.start() + } else { + _stop() + } + } + + private def _stop() { if (LiveListenerBus.withinListenerThread.value) { throw new SparkException( s"Cannot stop SparkContext within listener thread of ${LiveListenerBus.name}") diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 6550d703bc86..46e22b215b8e 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -20,14 +20,14 @@ package org.apache.spark import java.io.IOException import java.text.NumberFormat import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.fs.Path import org.apache.hadoop.mapred._ import org.apache.hadoop.mapreduce.TaskType import org.apache.spark.internal.Logging +import org.apache.spark.internal.io.SparkHadoopWriterUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD import org.apache.spark.util.SerializableJobConf @@ -67,12 +67,12 @@ class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { def setup(jobid: Int, splitid: Int, attemptid: Int) { setIDs(jobid, splitid, attemptid) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss").format(now), + HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(now), jobid, splitID, attemptID, conf.value) } def open() { - val numfmt = NumberFormat.getInstance() + val numfmt = NumberFormat.getInstance(Locale.US) numfmt.setMinimumIntegerDigits(5) numfmt.setGroupingUsed(false) @@ -153,29 +153,8 @@ class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { splitID = splitid attemptID = attemptid - jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobid)) + jID = new SerializableWritable[JobID](SparkHadoopWriterUtils.createJobID(now, jobid)) taID = new SerializableWritable[TaskAttemptID]( new TaskAttemptID(new TaskID(jID.value, TaskType.MAP, splitID), attemptID)) } } - -private[spark] -object SparkHadoopWriter { - def createJobID(time: Date, id: Int): JobID = { - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") - val jobtrackerID = formatter.format(time) - new JobID(jobtrackerID, id) - } - - def createPathFromString(path: String, conf: JobConf): Path = { - if (path == null) { - throw new IllegalArgumentException("Output path is null") - 
} - val outputPath = new Path(path) - val fs = outputPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException("Incorrectly formatted output path") - } - outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - } -} diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 41d0a85ee3ad..550746c552d0 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -22,12 +22,13 @@ import java.net.{InetAddress, InetSocketAddress, ServerSocket} import java.util.concurrent.TimeUnit import io.netty.bootstrap.ServerBootstrap -import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup} +import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption, EventLoopGroup} import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.handler.codec.LengthFieldBasedFrameDecoder import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} +import io.netty.handler.timeout.ReadTimeoutHandler import org.apache.spark.SparkConf import org.apache.spark.internal.Logging @@ -43,7 +44,10 @@ private[spark] class RBackend { def init(): Int = { val conf = new SparkConf() - bossGroup = new NioEventLoopGroup(conf.getInt("spark.r.numRBackendThreads", 2)) + val backendConnectionTimeout = conf.getInt( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) + bossGroup = new NioEventLoopGroup( + conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS)) val workerGroup = bossGroup val handler = new RBackendHandler(this) @@ -63,6 +67,7 @@ private[spark] class RBackend { // initialBytesToStrip = 4, i.e. strip out the length field itself new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) .addLast("decoder", new ByteArrayDecoder()) + .addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout)) .addLast("handler", handler) } }) @@ -110,6 +115,11 @@ private[spark] object RBackend extends Logging { val boundPort = sparkRBackend.init() val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val listenPort = serverSocket.getLocalPort() + // Connection timeout is set by socket client. 
To make it configurable we will pass the + // timeout value to client inside the temp file + val conf = new SparkConf() + val backendConnectionTimeout = conf.getInt( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) // tell the R process via temporary file val path = args(0) @@ -118,6 +128,7 @@ private[spark] object RBackend extends Logging { dos.writeInt(boundPort) dos.writeInt(listenPort) SerDe.writeString(dos, RUtils.rPackages.getOrElse("")) + dos.writeInt(backendConnectionTimeout) dos.close() f.renameTo(new File(path)) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 1422ef888fd4..9f5afa29d6d2 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -18,16 +18,19 @@ package org.apache.spark.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.util.concurrent.TimeUnit import scala.collection.mutable.HashMap import scala.language.existentials import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import io.netty.channel.ChannelHandler.Sharable +import io.netty.handler.timeout.ReadTimeoutException import org.apache.spark.api.r.SerDe._ import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils +import org.apache.spark.SparkConf +import org.apache.spark.util.{ThreadUtils, Utils} /** * Handler for RBackend @@ -83,7 +86,29 @@ private[r] class RBackendHandler(server: RBackend) writeString(dos, s"Error: unknown method $methodName") } } else { + // To avoid timeouts when reading results in SparkR driver, we will be regularly sending + // heartbeat responses. We use special code +1 to signal the client that backend is + // alive and it should continue blocking for result. + val execService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("SparkRKeepAliveThread") + val pingRunner = new Runnable { + override def run(): Unit = { + val pingBaos = new ByteArrayOutputStream() + val pingDaos = new DataOutputStream(pingBaos) + writeInt(pingDaos, +1) + ctx.write(pingBaos.toByteArray) + } + } + val conf = new SparkConf() + val heartBeatInterval = conf.getInt( + "spark.r.heartBeatInterval", SparkRDefaults.DEFAULT_HEARTBEAT_INTERVAL) + val backendConnectionTimeout = conf.getInt( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) + val interval = Math.min(heartBeatInterval, backendConnectionTimeout - 1) + + execService.scheduleAtFixedRate(pingRunner, interval, interval, TimeUnit.SECONDS) handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) + execService.shutdown() + execService.awaitTermination(1, TimeUnit.SECONDS) } val reply = bos.toByteArray @@ -95,9 +120,15 @@ private[r] class RBackendHandler(server: RBackend) } override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - // Close the connection when an exception is raised. - cause.printStackTrace() - ctx.close() + cause match { + case timeout: ReadTimeoutException => + // Do nothing. We don't want to timeout on read + logWarning("Ignoring read timeout in RBackendHandler") + case _ => + // Close the connection when an exception is raised. 
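[Editor's aside, not part of the patch] The keep-alive scheme above -- the backend writes the reserved status value 1 at a fixed interval while a method call is running, so the client knows to keep blocking for the real result -- can be sketched end to end with plain JDK pieces. The names, the in-memory buffer and the config values below are illustrative only:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.util.concurrent.{Executors, TimeUnit}

object KeepAliveSketch {
  def main(args: Array[String]): Unit = {
    val heartBeatInterval = 100    // stand-in for spark.r.heartBeatInterval (seconds)
    val connectionTimeout = 6000   // stand-in for spark.r.backendConnectionTimeout (seconds)
    // Ping more often than the client-side read timeout, as RBackendHandler does.
    val interval = math.min(heartBeatInterval, connectionTimeout - 1)

    val buffer = new ByteArrayOutputStream()
    val out = new DataOutputStream(buffer)
    val pinger = Executors.newSingleThreadScheduledExecutor()
    pinger.scheduleAtFixedRate(new Runnable {
      override def run(): Unit = buffer.synchronized { out.writeInt(1) }  // keep-alive marker
    }, interval, interval, TimeUnit.SECONDS)

    // ... the long-running method call would execute here ...
    buffer.synchronized { out.writeInt(0) }  // terminal status: success
    pinger.shutdown()

    // Client side: skip keep-alives, stop at the first terminal status.
    val in = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
    var status = in.readInt()
    while (status == 1) { status = in.readInt() }
    println(s"final status = $status")
  }
}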
+ cause.printStackTrace() + ctx.close() + } } def handleMethodCall( diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 496fdf851f7d..7ef64723d959 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -333,6 +333,8 @@ private[r] object RRunner { var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript") rCommand = sparkConf.get("spark.r.command", rCommand) + val rConnectionTimeout = sparkConf.getInt( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) val rOptions = "--vanilla" val rLibDir = RUtils.sparkRPackagePath(isDriver = false) val rExecScript = rLibDir(0) + "/SparkR/worker/" + script @@ -344,6 +346,7 @@ private[r] object RRunner { pb.environment().put("R_TESTS", "") pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) pb.environment().put("SPARKR_WORKER_PORT", port.toString) + pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString) pb.redirectErrorStream(true) // redirect stderr into stdout val proc = pb.start() val errThread = startStdoutThread(proc) diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index 77825e75e513..fdd8cf62f0e5 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -84,7 +84,6 @@ private[spark] object RUtils { } } else { // Otherwise, assume the package is local - // TODO: support this for Mesos val sparkRPkgPath = localSparkRPackagePath.getOrElse { throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.") } diff --git a/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala b/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala new file mode 100644 index 000000000000..af67cbbce4e5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.r + +private[spark] object SparkRDefaults { + + // Default value for spark.r.backendConnectionTimeout config + val DEFAULT_CONNECTION_TIMEOUT: Int = 6000 + + // Default value for spark.r.heartBeatInterval config + val DEFAULT_HEARTBEAT_INTERVAL: Int = 100 + + // Default value for spark.r.numRBackendThreads config + val DEFAULT_NUM_RBACKEND_THREADS = 2 +} diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index d0466830b217..6eb53a825220 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.spark.{SparkException, SparkUserAppException} -import org.apache.spark.api.r.{RBackend, RUtils} +import org.apache.spark.api.r.{RBackend, RUtils, SparkRDefaults} import org.apache.spark.util.RedirectThread /** @@ -51,6 +51,10 @@ object RRunner { cmd } + // Connection timeout set by R process on its connection to RBackend in seconds. + val backendConnectionTimeout = sys.props.getOrElse( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT.toString) + // Check if the file path exists. // If not, change directory to current working directory for YARN cluster mode val rF = new File(rFile) @@ -81,6 +85,7 @@ object RRunner { val builder = new ProcessBuilder((Seq(rCommand, rFileNormalized) ++ otherArgs).asJava) val env = builder.environment() env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) + env.put("SPARKR_BACKEND_CONNECTION_TIMEOUT", backendConnectionTimeout) val rPackageDir = RUtils.sparkRPackagePath(isDriver = true) // Put the R package directories into an env variable of comma-separated paths env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(",")) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 3f54ecc17ac3..23156072c3eb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -21,7 +21,7 @@ import java.io.IOException import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import java.text.DateFormat -import java.util.{Arrays, Comparator, Date} +import java.util.{Arrays, Comparator, Date, Locale} import scala.collection.JavaConverters._ import scala.util.control.NonFatal @@ -357,7 +357,7 @@ class SparkHadoopUtil extends Logging { * @return a printable string value. 
*/ private[spark] def tokenToString(token: Token[_ <: TokenIdentifier]): String = { - val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT) + val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT, Locale.US) val buffer = new StringBuilder(128) buffer.append(token.toString) try { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 5c052286099f..c70061bc5b5b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -322,7 +322,7 @@ object SparkSubmit { } // Require all R files to be local - if (args.isR && !isYarnCluster) { + if (args.isR && !isYarnCluster && !isMesosCluster) { if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { printErrorAndExit(s"Only local R files are supported: ${args.primaryResource}") } @@ -330,9 +330,6 @@ object SparkSubmit { // The following modes are not supported or applicable (clusterManager, deployMode) match { - case (MESOS, CLUSTER) if args.isR => - printErrorAndExit("Cluster deploy mode is currently not supported for R " + - "applications on Mesos clusters.") case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") @@ -410,9 +407,9 @@ object SparkSubmit { printErrorAndExit("Distributing R packages with standalone cluster is not supported.") } - // TODO: Support SparkR with mesos cluster - if (args.isR && clusterManager == MESOS) { - printErrorAndExit("SparkR is not supported for Mesos cluster.") + // TODO: Support distributing R packages with mesos cluster + if (args.isR && clusterManager == MESOS && !RUtils.rPackages.isEmpty) { + printErrorAndExit("Distributing R packages with mesos cluster is not supported.") } // If we're running an R app, set the main class to our specific R runner @@ -598,6 +595,9 @@ object SparkSubmit { if (args.pyFiles != null) { sysProps("spark.submit.pyFiles") = args.pyFiles } + } else if (args.isR) { + // Second argument is main class + childArgs += (args.primaryResource, "") } else { childArgs += (args.primaryResource, args.mainClass) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 8c91aa15167c..4618e6117a4f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.master import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -51,7 +51,8 @@ private[deploy] class Master( private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + // For application IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) private val WORKER_TIMEOUT_MS = conf.getLong("spark.worker.timeout", 60) * 1000 private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 0bedd9a20a96..8b1c6bf2e5fd 100755 --- 
a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy.worker import java.io.File import java.io.IOException import java.text.SimpleDateFormat -import java.util.{Date, UUID} +import java.util.{Date, Locale, UUID} import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} @@ -68,7 +68,7 @@ private[deploy] class Worker( ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread")) // For worker and executor IDs - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index f66510b6f977..59404e08895a 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -27,6 +27,9 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit} +import org.apache.spark.internal.config +import org.apache.spark.SparkContext + /** * A general format for reading whole files in as streams, byte arrays, * or other functions to be added @@ -40,9 +43,14 @@ private[spark] abstract class StreamFileInputFormat[T] * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API * which is set through setMaxSplitSize */ - def setMinPartitions(context: JobContext, minPartitions: Int) { - val totalLen = listStatus(context).asScala.filterNot(_.isDirectory).map(_.getLen).sum - val maxSplitSize = math.ceil(totalLen / math.max(minPartitions, 1.0)).toLong + def setMinPartitions(sc: SparkContext, context: JobContext, minPartitions: Int) { + val defaultMaxSplitBytes = sc.getConf.get(config.FILES_MAX_PARTITION_BYTES) + val openCostInBytes = sc.getConf.get(config.FILES_OPEN_COST_IN_BYTES) + val defaultParallelism = sc.defaultParallelism + val files = listStatus(context).asScala + val totalBytes = files.filterNot(_.isDirectory).map(_.getLen + openCostInBytes).sum + val bytesPerCore = totalBytes / defaultParallelism + val maxSplitSize = Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore)) super.setMaxSplitSize(maxSplitSize) } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 497ca92c7bc6..4a3e3d5c79ef 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -206,4 +206,17 @@ package object config { "encountering corrupt files and contents that have been read will still be returned.") .booleanConf .createWithDefault(false) + + private[spark] val FILES_MAX_PARTITION_BYTES = ConfigBuilder("spark.files.maxPartitionBytes") + .doc("The maximum number of bytes to pack into a single partition when reading files.") + .longConf + .createWithDefault(128 * 1024 * 1024) + + private[spark] val FILES_OPEN_COST_IN_BYTES = 
ConfigBuilder("spark.files.openCostInBytes") + .doc("The estimated cost to open a file, measured by the number of bytes could be scanned in" + + " the same time. This is used when putting multiple files into a partition. It's better to" + + " over estimate, then the partitions with small files will be faster than partitions with" + + " bigger files.") + .longConf + .createWithDefault(4 * 1024 * 1024) } diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala new file mode 100644 index 000000000000..fb8020585cf8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import org.apache.hadoop.mapreduce._ + +import org.apache.spark.util.Utils + + +/** + * An interface to define how a single Spark job commits its outputs. Two notes: + * + * 1. Implementations must be serializable, as the committer instance instantiated on the driver + * will be used for tasks on executors. + * 2. Implementations should have a constructor with either 2 or 3 arguments: + * (jobId: String, path: String) or (jobId: String, path: String, isAppend: Boolean). + * 3. A committer should not be reused across multiple Spark jobs. + * + * The proper call sequence is: + * + * 1. Driver calls setupJob. + * 2. As part of each task's execution, executor calls setupTask and then commitTask + * (or abortTask if task failed). + * 3. When all necessary tasks completed successfully, the driver calls commitJob. If the job + * failed to execute (e.g. too many failed tasks), the job should call abortJob. + */ +abstract class FileCommitProtocol { + import FileCommitProtocol._ + + /** + * Setups up a job. Must be called on the driver before any other methods can be invoked. + */ + def setupJob(jobContext: JobContext): Unit + + /** + * Commits a job after the writes succeed. Must be called on the driver. + */ + def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit + + /** + * Aborts a job after the writes fail. Must be called on the driver. + * + * Calling this function is a best-effort attempt, because it is possible that the driver + * just crashes (or killed) before it can call abort. + */ + def abortJob(jobContext: JobContext): Unit + + /** + * Sets up a task within a job. + * Must be called before any other task related methods can be invoked. + */ + def setupTask(taskContext: TaskAttemptContext): Unit + + /** + * Notifies the commit protocol to add a new file, and gets back the full path that should be + * used. Must be called on the executors when running tasks. 
+ * + * Note that the returned temp file may have an arbitrary path. The commit protocol only + * promises that the file will be at the location specified by the arguments after job commit. + * + * A full file path consists of the following parts: + * 1. the base path + * 2. some sub-directory within the base path, used to specify partitioning + * 3. file prefix, usually some unique job id with the task id + * 4. bucket id + * 5. source specific file extension, e.g. ".snappy.parquet" + * + * The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest + * are left to the commit protocol implementation to decide. + */ + def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String + + /** + * Commits a task after the writes succeed. Must be called on the executors when running tasks. + */ + def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage + + /** + * Aborts a task after the writes have failed. Must be called on the executors when running tasks. + * + * Calling this function is a best-effort attempt, because it is possible that the executor + * just crashes (or killed) before it can call abort. + */ + def abortTask(taskContext: TaskAttemptContext): Unit +} + + +object FileCommitProtocol { + class TaskCommitMessage(val obj: Any) extends Serializable + + object EmptyTaskCommitMessage extends TaskCommitMessage(null) + + /** + * Instantiates a FileCommitProtocol using the given className. + */ + def instantiate(className: String, jobId: String, outputPath: String, isAppend: Boolean) + : FileCommitProtocol = { + val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] + + // First try the one with argument (jobId: String, outputPath: String, isAppend: Boolean). + // If that doesn't exist, try the one with (jobId: string, outputPath: String). + try { + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean]) + ctor.newInstance(jobId, outputPath, isAppend.asInstanceOf[java.lang.Boolean]) + } catch { + case _: NoSuchMethodException => + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) + ctor.newInstance(jobId, outputPath) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala new file mode 100644 index 000000000000..6b0bcb8f908b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
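[Editor's aside, not part of the patch] FileCommitProtocol.instantiate above resolves the committer constructor reflectively, preferring (jobId, path, isAppend) and falling back to (jobId, path). A self-contained sketch of that fallback with toy classes (all names hypothetical):

class TwoArgCommitter(jobId: String, path: String)
class ThreeArgCommitter(jobId: String, path: String, isAppend: Boolean)

object InstantiateSketch {
  def instantiate(clazz: Class[_], jobId: String, path: String, isAppend: Boolean): Any = {
    try {
      // Prefer the three-argument shape, as FileCommitProtocol.instantiate does.
      val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean])
      ctor.newInstance(jobId, path, isAppend.asInstanceOf[java.lang.Boolean])
    } catch {
      case _: NoSuchMethodException =>
        val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String])
        ctor.newInstance(jobId, path)
    }
  }

  def main(args: Array[String]): Unit = {
    println(instantiate(classOf[ThreeArgCommitter], "job-0", "/tmp/out", isAppend = true))
    println(instantiate(classOf[TwoArgCommitter], "job-0", "/tmp/out", isAppend = false))
  }
}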
+ */ + +package org.apache.spark.internal.io + +import java.util.Date + +import org.apache.hadoop.conf.Configurable +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + +import org.apache.spark.internal.Logging +import org.apache.spark.mapred.SparkHadoopMapRedUtil + +/** + * An [[FileCommitProtocol]] implementation backed by an underlying Hadoop OutputCommitter + * (from the newer mapreduce API, not the old mapred API). + * + * Unlike Hadoop's OutputCommitter, this implementation is serializable. + */ +class HadoopMapReduceCommitProtocol(jobId: String, path: String) + extends FileCommitProtocol with Serializable with Logging { + + import FileCommitProtocol._ + + /** OutputCommitter from Hadoop is not serializable so marking it transient. */ + @transient private var committer: OutputCommitter = _ + + protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { + val format = context.getOutputFormatClass.newInstance() + // If OutputFormat is Configurable, we should set conf to it. + format match { + case c: Configurable => c.setConf(context.getConfiguration) + case _ => () + } + format.getOutputCommitter(context) + } + + override def newTaskTempFile( + taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { + // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet + // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, + // the file name is fine and won't overflow. + val split = taskContext.getTaskAttemptID.getTaskID.getId + val filename = f"part-$split%05d-$jobId$ext" + + val stagingDir: String = committer match { + // For FileOutputCommitter it has its own staging path called "work path". 
+ case f: FileOutputCommitter => Option(f.getWorkPath.toString).getOrElse(path) + case _ => path + } + + dir.map { d => + new Path(new Path(stagingDir, d), filename).toString + }.getOrElse { + new Path(stagingDir, filename).toString + } + } + + override def setupJob(jobContext: JobContext): Unit = { + // Setup IDs + val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0) + val taskId = new TaskID(jobId, TaskType.MAP, 0) + val taskAttemptId = new TaskAttemptID(taskId, 0) + + // Set up the configuration object + jobContext.getConfiguration.set("mapred.job.id", jobId.toString) + jobContext.getConfiguration.set("mapred.tip.id", taskAttemptId.getTaskID.toString) + jobContext.getConfiguration.set("mapred.task.id", taskAttemptId.toString) + jobContext.getConfiguration.setBoolean("mapred.task.is.map", true) + jobContext.getConfiguration.setInt("mapred.task.partition", 0) + + val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId) + committer = setupCommitter(taskAttemptContext) + committer.setupJob(jobContext) + } + + override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { + committer.commitJob(jobContext) + } + + override def abortJob(jobContext: JobContext): Unit = { + committer.abortJob(jobContext, JobStatus.State.FAILED) + } + + override def setupTask(taskContext: TaskAttemptContext): Unit = { + committer = setupCommitter(taskContext) + committer.setupTask(taskContext) + } + + override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { + val attemptId = taskContext.getTaskAttemptID + SparkHadoopMapRedUtil.commitTask( + committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) + EmptyTaskCommitMessage + } + + override def abortTask(taskContext: TaskAttemptContext): Unit = { + committer.abortTask(taskContext) + } + + /** Whether we are using a direct output committer */ + def isDirectOutput(): Boolean = committer.getClass.getSimpleName.contains("Direct") +} diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala new file mode 100644 index 000000000000..796439276a22 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.internal.io + +import java.text.SimpleDateFormat +import java.util.{Date, Locale} + +import scala.reflect.ClassTag +import scala.util.DynamicVariable + +import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapred.{JobConf, JobID} +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + +import org.apache.spark.{SparkConf, SparkException, TaskContext} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.OutputMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage +import org.apache.spark.rdd.RDD +import org.apache.spark.util.{SerializableConfiguration, Utils} + +/** + * A helper object that saves an RDD using a Hadoop OutputFormat + * (from the newer mapreduce API, not the old mapred API). + */ +private[spark] +object SparkHadoopMapReduceWriter extends Logging { + + /** + * Basic work flow of this command is: + * 1. Driver side setup, prepare the data source and hadoop configuration for the write job to + * be issued. + * 2. Issues a write job consists of one or more executor side tasks, each of which writes all + * rows within an RDD partition. + * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any + * exception is thrown during task commitment, also aborts that task. + * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is + * thrown during job commitment, also aborts the job. + */ + def write[K, V: ClassTag]( + rdd: RDD[(K, V)], + hadoopConf: Configuration): Unit = { + // Extract context and configuration from RDD. + val sparkContext = rdd.context + val stageId = rdd.id + val sparkConf = rdd.conf + val conf = new SerializableConfiguration(hadoopConf) + + // Set up a job. + val jobTrackerId = SparkHadoopWriterUtils.createJobTrackerID(new Date()) + val jobAttemptId = new TaskAttemptID(jobTrackerId, stageId, TaskType.MAP, 0, 0) + val jobContext = new TaskAttemptContextImpl(conf.value, jobAttemptId) + val format = jobContext.getOutputFormatClass + + if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(sparkConf)) { + // FileOutputFormat ignores the filesystem parameter + val jobFormat = format.newInstance + jobFormat.checkOutputSpecs(jobContext) + } + + val committer = FileCommitProtocol.instantiate( + className = classOf[HadoopMapReduceCommitProtocol].getName, + jobId = stageId.toString, + outputPath = conf.value.get("mapred.output.dir"), + isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol] + committer.setupJob(jobContext) + + // When speculation is on and output committer class name contains "Direct", we should warn + // users that they may loss data if they are using a direct output committer. + if (SparkHadoopWriterUtils.isSpeculationEnabled(sparkConf) && committer.isDirectOutput) { + val warningMessage = + s"$committer may be an output committer that writes data directly to " + + "the final location. Because speculation is enabled, this output committer may " + + "cause data loss (see the case in SPARK-10063). If possible, please use an output " + + "committer that does not have this behavior (e.g. FileOutputCommitter)." + logWarning(warningMessage) + } + + // Try to write all RDD partitions as a Hadoop OutputFormat. 
+ try { + val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => { + executeTask( + context = context, + jobTrackerId = jobTrackerId, + sparkStageId = context.stageId, + sparkPartitionId = context.partitionId, + sparkAttemptNumber = context.attemptNumber, + committer = committer, + hadoopConf = conf.value, + outputFormat = format.asInstanceOf[Class[OutputFormat[K, V]]], + iterator = iter) + }) + + committer.commitJob(jobContext, ret) + logInfo(s"Job ${jobContext.getJobID} committed.") + } catch { + case cause: Throwable => + logError(s"Aborting job ${jobContext.getJobID}.", cause) + committer.abortJob(jobContext) + throw new SparkException("Job aborted.", cause) + } + } + + /** Write a RDD partition out in a single Spark task. */ + private def executeTask[K, V: ClassTag]( + context: TaskContext, + jobTrackerId: String, + sparkStageId: Int, + sparkPartitionId: Int, + sparkAttemptNumber: Int, + committer: FileCommitProtocol, + hadoopConf: Configuration, + outputFormat: Class[_ <: OutputFormat[K, V]], + iterator: Iterator[(K, V)]): TaskCommitMessage = { + // Set up a task. + val attemptId = new TaskAttemptID(jobTrackerId, sparkStageId, TaskType.REDUCE, + sparkPartitionId, sparkAttemptNumber) + val taskContext = new TaskAttemptContextImpl(hadoopConf, attemptId) + committer.setupTask(taskContext) + + val outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)] = + SparkHadoopWriterUtils.initHadoopOutputMetrics(context) + + // Initiate the writer. + val taskFormat = outputFormat.newInstance() + // If OutputFormat is Configurable, we should set conf to it. + taskFormat match { + case c: Configurable => c.setConf(hadoopConf) + case _ => () + } + val writer = taskFormat.getRecordWriter(taskContext) + .asInstanceOf[RecordWriter[K, V]] + require(writer != null, "Unable to obtain RecordWriter") + var recordsWritten = 0L + + // Write all rows in RDD partition. 
+ try { + val ret = Utils.tryWithSafeFinallyAndFailureCallbacks { + while (iterator.hasNext) { + val pair = iterator.next() + writer.write(pair._1, pair._2) + + // Update bytes written metric every few records + SparkHadoopWriterUtils.maybeUpdateOutputMetrics( + outputMetricsAndBytesWrittenCallback, recordsWritten) + recordsWritten += 1 + } + + committer.commitTask(taskContext) + }(catchBlock = { + committer.abortTask(taskContext) + logError(s"Task ${taskContext.getTaskAttemptID} aborted.") + }, finallyBlock = writer.close(taskContext)) + + outputMetricsAndBytesWrittenCallback.foreach { + case (om, callback) => + om.setBytesWritten(callback()) + om.setRecordsWritten(recordsWritten) + } + + ret + } catch { + case t: Throwable => + throw new SparkException("Task failed while writing rows", t) + } + } +} + +private[spark] +object SparkHadoopWriterUtils { + + private val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256 + + def createJobID(time: Date, id: Int): JobID = { + val jobtrackerID = createJobTrackerID(time) + new JobID(jobtrackerID, id) + } + + def createJobTrackerID(time: Date): String = { + new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(time) + } + + def createPathFromString(path: String, conf: JobConf): Path = { + if (path == null) { + throw new IllegalArgumentException("Output path is null") + } + val outputPath = new Path(path) + val fs = outputPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException("Incorrectly formatted output path") + } + outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + } + + // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation + // setting can take effect: + def isOutputSpecValidationEnabled(conf: SparkConf): Boolean = { + val validationDisabled = disableOutputSpecValidation.value + val enabledInConf = conf.getBoolean("spark.hadoop.validateOutputSpecs", true) + enabledInConf && !validationDisabled + } + + def isSpeculationEnabled(conf: SparkConf): Boolean = { + conf.getBoolean("spark.speculation", false) + } + + // TODO: these don't seem like the right abstractions. + // We should abstract the duplicate code in a less awkward way. + + // return type: (output metrics, bytes written callback), defined only if the latter is defined + def initHadoopOutputMetrics( + context: TaskContext): Option[(OutputMetrics, () => Long)] = { + val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() + bytesWrittenCallback.map { b => + (context.taskMetrics().outputMetrics, b) + } + } + + def maybeUpdateOutputMetrics( + outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)], + recordsWritten: Long): Unit = { + if (recordsWritten % RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) { + outputMetricsAndBytesWrittenCallback.foreach { + case (om, callback) => + om.setBytesWritten(callback()) + om.setRecordsWritten(recordsWritten) + } + } + } + + /** + * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case + * basis; see SPARK-4835 for more details. 
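[Editor's aside, not part of the patch] The Locale.US arguments added to SimpleDateFormat and NumberFormat throughout this change keep job tracker IDs and part-file numbers ASCII-only no matter what the JVM's default locale is. A small standalone illustration (output comments are indicative and depend on the JDK's locale data and time zone):

import java.text.{NumberFormat, SimpleDateFormat}
import java.util.{Date, Locale}

object LocaleSketch {
  def main(args: Array[String]): Unit = {
    // Job tracker IDs: same pattern as createJobTrackerID above.
    println(new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(new Date()))

    // Part-file numbering: five digits, no grouping separators.
    val us = NumberFormat.getInstance(Locale.US)
    us.setMinimumIntegerDigits(5)
    us.setGroupingUsed(false)
    println(us.format(3))  // "00003"

    // The same calls under another default locale may yield non-ASCII digits,
    // which is exactly what pinning Locale.US avoids.
    val ar = NumberFormat.getInstance(new Locale("ar", "SA"))
    ar.setMinimumIntegerDigits(5)
    ar.setGroupingUsed(false)
    println(ar.format(3))
  }
}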
+ */ + val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) +} diff --git a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala index b54885b7ff8b..3f7cfd9d2c11 100644 --- a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala +++ b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala @@ -76,7 +76,7 @@ object HiveCatalogMetrics extends Source { val METRIC_PARTITIONS_FETCHED = metricRegistry.counter(MetricRegistry.name("partitionsFetched")) /** - * Tracks the total number of files discovered off of the filesystem by ListingFileCatalog. + * Tracks the total number of files discovered off of the filesystem by InMemoryFileIndex. */ val METRIC_FILES_DISCOVERED = metricRegistry.counter(MetricRegistry.name("filesDiscovered")) diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index 41832e835474..50d977a92da5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.{Partition, SparkContext} import org.apache.spark.input.StreamFileInputFormat private[spark] class BinaryFileRDD[T]( - sc: SparkContext, + @transient private val sc: SparkContext, inputFormatClass: Class[_ <: StreamFileInputFormat[T]], keyClass: Class[String], valueClass: Class[T], @@ -43,7 +43,7 @@ private[spark] class BinaryFileRDD[T]( case _ => } val jobContext = new JobContextImpl(conf, jobId) - inputFormat.setMinPartitions(jobContext, minPartitions) + inputFormat.setMinPartitions(sc, jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index e1cf3938de09..36a2f5c87e37 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.io.IOException import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.collection.immutable.Map import scala.reflect.ClassTag @@ -243,7 +243,8 @@ class HadoopRDD[K, V]( var reader: RecordReader[K, V] = null val inputFormat = getInputFormat(jobConf) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss").format(createTime), + HadoopRDD.addLocalConfiguration( + new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(createTime), context.stageId, theSplit.index, context.attemptNumber, jobConf) reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index baf31fb65887..488e777fea37 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.io.IOException import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.reflect.ClassTag @@ -79,7 +79,7 @@ class NewHadoopRDD[K, V]( // private val serializableConf = new SerializableWritable(_conf) private val 
jobTrackerId: String = { - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") + val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) formatter.format(new Date()) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 068f4ed8ad74..f9b9631d9e7c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -18,33 +18,31 @@ package org.apache.spark.rdd import java.nio.ByteBuffer -import java.text.SimpleDateFormat -import java.util.{Date, HashMap => JHashMap} +import java.util.{HashMap => JHashMap} import scala.collection.{mutable, Map} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag -import scala.util.DynamicVariable import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus -import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} -import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptID, TaskType} -import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat} import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.OutputMetrics +import org.apache.spark.internal.io.{FileCommitProtocol, HadoopMapReduceCommitProtocol, SparkHadoopMapReduceWriter, SparkHadoopWriterUtils} import org.apache.spark.internal.Logging import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.Utils import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.util.random.StratifiedSamplingUtils @@ -1060,7 +1058,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } FileOutputFormat.setOutputPath(hadoopConf, - SparkHadoopWriter.createPathFromString(path, hadoopConf)) + SparkHadoopWriterUtils.createPathFromString(path, hadoopConf)) saveAsHadoopDataset(hadoopConf) } @@ -1076,80 +1074,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * result of using direct output committer with speculation enabled. */ def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope { - // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). 
- val hadoopConf = conf - val job = NewAPIHadoopJob.getInstance(hadoopConf) - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") - val jobtrackerID = formatter.format(new Date()) - val stageId = self.id - val jobConfiguration = job.getConfiguration - val wrappedConf = new SerializableConfiguration(jobConfiguration) - val outfmt = job.getOutputFormatClass - val jobFormat = outfmt.newInstance - - if (isOutputSpecValidationEnabled) { - // FileOutputFormat ignores the filesystem parameter - jobFormat.checkOutputSpecs(job) - } - - val writeShard = (context: TaskContext, iter: Iterator[(K, V)]) => { - val config = wrappedConf.value - /* "reduce task" */ - val attemptId = new TaskAttemptID(jobtrackerID, stageId, TaskType.REDUCE, context.partitionId, - context.attemptNumber) - val hadoopContext = new TaskAttemptContextImpl(config, attemptId) - val format = outfmt.newInstance - format match { - case c: Configurable => c.setConf(config) - case _ => () - } - val committer = format.getOutputCommitter(hadoopContext) - committer.setupTask(hadoopContext) - - val outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)] = - initHadoopOutputMetrics(context) - - val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]] - require(writer != null, "Unable to obtain RecordWriter") - var recordsWritten = 0L - Utils.tryWithSafeFinallyAndFailureCallbacks { - while (iter.hasNext) { - val pair = iter.next() - writer.write(pair._1, pair._2) - - // Update bytes written metric every few records - maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten) - recordsWritten += 1 - } - }(finallyBlock = writer.close(hadoopContext)) - committer.commitTask(hadoopContext) - outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) => - om.setBytesWritten(callback()) - om.setRecordsWritten(recordsWritten) - } - 1 - } : Int - - val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, TaskType.MAP, 0, 0) - val jobTaskContext = new TaskAttemptContextImpl(wrappedConf.value, jobAttemptId) - val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) - - // When speculation is on and output committer class name contains "Direct", we should warn - // users that they may loss data if they are using a direct output committer. - val speculationEnabled = self.conf.getBoolean("spark.speculation", false) - val outputCommitterClass = jobCommitter.getClass.getSimpleName - if (speculationEnabled && outputCommitterClass.contains("Direct")) { - val warningMessage = - s"$outputCommitterClass may be an output committer that writes data directly to " + - "the final location. Because speculation is enabled, this output committer may " + - "cause data loss (see the case in SPARK-10063). If possible, please use an output " + - "committer that does not have this behavior (e.g. FileOutputCommitter)." 
- logWarning(warningMessage) - } - - jobCommitter.setupJob(jobTaskContext) - self.context.runJob(self, writeShard) - jobCommitter.commitJob(jobTaskContext) + SparkHadoopMapReduceWriter.write( + rdd = self, + hadoopConf = conf) } /** @@ -1178,7 +1105,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + valueClass.getSimpleName + ")") - if (isOutputSpecValidationEnabled) { + if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(self.conf)) { // FileOutputFormat ignores the filesystem parameter val ignoredFs = FileSystem.get(hadoopConf) hadoopConf.getOutputFormat.checkOutputSpecs(ignoredFs, hadoopConf) @@ -1193,7 +1120,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt val outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)] = - initHadoopOutputMetrics(context) + SparkHadoopWriterUtils.initHadoopOutputMetrics(context) writer.setup(context.stageId, context.partitionId, taskAttemptId) writer.open() @@ -1205,7 +1132,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) // Update bytes written metric every few records - maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten) + SparkHadoopWriterUtils.maybeUpdateOutputMetrics( + outputMetricsAndBytesWrittenCallback, recordsWritten) recordsWritten += 1 } }(finallyBlock = writer.close()) @@ -1220,29 +1148,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.commitJob() } - // TODO: these don't seem like the right abstractions. - // We should abstract the duplicate code in a less awkward way. - - // return type: (output metrics, bytes written callback), defined only if the latter is defined - private def initHadoopOutputMetrics( - context: TaskContext): Option[(OutputMetrics, () => Long)] = { - val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() - bytesWrittenCallback.map { b => - (context.taskMetrics().outputMetrics, b) - } - } - - private def maybeUpdateOutputMetrics( - outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)], - recordsWritten: Long): Unit = { - if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) { - outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) => - om.setBytesWritten(callback()) - om.setRecordsWritten(recordsWritten) - } - } - } - /** * Return an RDD with the keys of each tuple. */ @@ -1258,22 +1163,4 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) private[spark] def valueClass: Class[_] = vt.runtimeClass private[spark] def keyOrdering: Option[Ordering[K]] = Option(ord) - - // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation - // setting can take effect: - private def isOutputSpecValidationEnabled: Boolean = { - val validationDisabled = PairRDDFunctions.disableOutputSpecValidation.value - val enabledInConf = self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true) - enabledInConf && !validationDisabled - } -} - -private[spark] object PairRDDFunctions { - val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256 - - /** - * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case - * basis; see SPARK-4835 for more details. 
- */ - val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index db535de9e9bb..e018af35cb18 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -788,14 +788,26 @@ abstract class RDD[T: ClassTag]( } /** - * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a - * performance API to be used carefully only if we are sure that the RDD elements are + * [performance] Spark's internal mapPartitionsWithIndex method that skips closure cleaning. + * It is a performance API to be used carefully only if we are sure that the RDD elements are * serializable and don't require closure cleaning. * * @param preservesPartitioning indicates whether the input function preserves the partitioner, * which should be `false` unless this is a pair RDD and the input function doesn't modify * the keys. */ + private[spark] def mapPartitionsWithIndexInternal[U: ClassTag]( + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = withScope { + new MapPartitionsRDD( + this, + (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter), + preservesPartitioning) + } + + /** + * [performance] Spark's internal mapPartitions method that skips closure cleaning. + */ private[spark] def mapPartitionsInternal[U: ClassTag]( f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = withScope { diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index eac901d10067..9f800e3a0953 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -239,12 +239,17 @@ private[spark] object ReliableCheckpointRDD extends Logging { val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration) val fileInputStream = fs.open(partitionerFilePath, bufferSize) val serializer = SparkEnv.get.serializer.newInstance() - val deserializeStream = serializer.deserializeStream(fileInputStream) - val partitioner = Utils.tryWithSafeFinally[Partitioner] { - deserializeStream.readObject[Partitioner] + val partitioner = Utils.tryWithSafeFinally { + val deserializeStream = serializer.deserializeStream(fileInputStream) + Utils.tryWithSafeFinally { + deserializeStream.readObject[Partitioner] + } { + deserializeStream.close() + } } { - deserializeStream.close() + fileInputStream.close() } + logDebug(s"Read partitioner from $partitionerFilePath") Some(partitioner) } catch { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 579122868afc..bbc416381490 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -147,6 +147,10 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { */ def openChannel(uri: String): ReadableByteChannel + /** + * Return if the current thread is a RPC thread. 
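// A simplified stand-in for the nested-close pattern in the ReliableCheckpointRDD hunk above
// (plain try/finally instead of Spark's Utils.tryWithSafeFinally): the outer finally closes the
// raw file stream even if building or closing the deserialization stream throws, so neither
// handle can leak.
import java.io.{BufferedInputStream, FileInputStream, ObjectInputStream}

object NestedCloseSketch {
  def readFirstObject(path: String): AnyRef = {
    val fileIn = new FileInputStream(path)
    try {
      val objIn = new ObjectInputStream(new BufferedInputStream(fileIn))
      try objIn.readObject()
      finally objIn.close()
    } finally {
      fileIn.close()
    }
  }
}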
+ */ + def isInRPCThread: Boolean } /** diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index a02cf30a5d83..67baabd2cbff 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -201,6 +201,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { /** Message loop used for dispatching messages. */ private class MessageLoop extends Runnable { override def run(): Unit = { + NettyRpcEnv.rpcThreadFlag.value = true try { while (true) { try { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index e51649a1ecce..0b8cd144a216 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -408,10 +408,13 @@ private[netty] class NettyRpcEnv( } + override def isInRPCThread: Boolean = NettyRpcEnv.rpcThreadFlag.value } private[netty] object NettyRpcEnv extends Logging { + private[netty] val rpcThreadFlag = new DynamicVariable[Boolean](false) + /** * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]]. * Use `currentEnv` to wrap the deserialization codes. E.g., diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f2517401cb76..7fde34d8974c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1089,7 +1089,8 @@ class DAGScheduler( // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && !updates.isZero) { stage.latestInfo.accumulables(id) = acc.toInfo(None, Some(acc.value)) - event.taskInfo.accumulables += acc.toInfo(Some(updates.value), Some(acc.value)) + event.taskInfo.setAccumulables( + acc.toInfo(Some(updates.value), Some(acc.value)) +: event.taskInfo.accumulables) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 3eff8d952bfd..0bd5a6bc59a9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -53,13 +53,24 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { sourceName: String, maybeTruncated: Boolean = false, eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): Unit = { + val lines = Source.fromInputStream(logData).getLines() + replay(lines, sourceName, maybeTruncated, eventsFilter) + } + /** + * Overloaded variant of [[replay()]] which accepts an iterator of lines instead of an + * [[InputStream]]. Exposed for use by custom ApplicationHistoryProvider implementations. 
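// An illustrative sketch of the DynamicVariable-backed flag introduced above: the flag is
// per-thread, so setting it at the top of a dispatcher's message loop marks only that thread
// as an RPC thread, which isInRPCThread can later query cheaply.
import scala.util.DynamicVariable

object RpcThreadFlagSketch {
  private val rpcThreadFlag = new DynamicVariable[Boolean](false)

  def isInRpcThread: Boolean = rpcThreadFlag.value

  def main(args: Array[String]): Unit = {
    val loop = new Thread(new Runnable {
      override def run(): Unit = {
        rpcThreadFlag.value = true          // as the message loop does on entry
        assert(isInRpcThread)
      }
    })
    loop.start()
    loop.join()
    assert(!isInRpcThread)                  // other threads are unaffected
  }
}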
+ */ + def replay( + lines: Iterator[String], + sourceName: String, + maybeTruncated: Boolean, + eventsFilter: ReplayEventsFilter): Unit = { var currentLine: String = null var lineNumber: Int = 0 try { - val lineEntries = Source.fromInputStream(logData) - .getLines() + val lineEntries = lines .zipWithIndex .filter { case (line, _) => eventsFilter(line) } @@ -72,6 +83,10 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { postToAll(JsonProtocol.sparkEventFromJson(parse(currentLine))) } catch { + case e: ClassNotFoundException if KNOWN_REMOVED_CLASSES.contains(e.getMessage) => + // Ignore events generated by Structured Streaming in Spark 2.0.0 and 2.0.1. + // It's safe since no place uses them. + logWarning(s"Dropped incompatible Structured Streaming log: $currentLine") case jpe: JsonParseException => // We can only ignore exception from last line of the file that might be truncated // the last entry may not be the very last line in the event log, but we treat it @@ -102,4 +117,13 @@ private[spark] object ReplayListenerBus { // utility filter that selects all event logs during replay val SELECT_ALL_FILTER: ReplayEventsFilter = { (eventString: String) => true } + + /** + * Classes that were removed. Structured Streaming doesn't use them any more. However, parsing + * old json may fail and we can just ignore these failures. + */ + val KNOWN_REMOVED_CLASSES = Set( + "org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgress", + "org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminated" + ) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index eeb7963c9e61..59680139e7af 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -17,8 +17,6 @@ package org.apache.spark.scheduler -import scala.collection.mutable.ListBuffer - import org.apache.spark.TaskState import org.apache.spark.TaskState.TaskState import org.apache.spark.annotation.DeveloperApi @@ -54,7 +52,13 @@ class TaskInfo( * accumulable to be updated multiple times in a single task or for two accumulables with the * same name but different IDs to exist in a task. 
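// A hedged, self-contained sketch of the tolerant replay loop above (the parse and post
// functions stand in for JsonProtocol.sparkEventFromJson and postToAll): events whose classes
// were later removed are logged and skipped instead of aborting the whole replay.
object TolerantReplaySketch {
  val knownRemovedClasses: Set[String] = Set(
    "org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgress",
    "org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminated")

  def replay(lines: Iterator[String], parse: String => AnyRef, post: AnyRef => Unit): Unit = {
    lines.filter(_.nonEmpty).foreach { line =>
      try {
        post(parse(line))
      } catch {
        case e: ClassNotFoundException if knownRemovedClasses.contains(e.getMessage) =>
          Console.err.println(s"Dropped incompatible event: $line")
      }
    }
  }
}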
*/ - val accumulables = ListBuffer[AccumulableInfo]() + def accumulables: Seq[AccumulableInfo] = _accumulables + + private[this] var _accumulables: Seq[AccumulableInfo] = Nil + + private[spark] def setAccumulables(newAccumulables: Seq[AccumulableInfo]): Unit = { + _accumulables = newAccumulables + } /** * The time when the task has completed successfully (including the time to remotely fetch diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 04d40e2907cf..368cd30a2e11 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -93,7 +93,7 @@ private[spark] class StandaloneSchedulerBackend( val javaOpts = sparkJavaOpts ++ extraJavaOpts val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts) - val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") + val webUrl = sc.ui.map(_.webUrl).getOrElse("") val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) // If we're using dynamic allocation, set our initial executor limit to 0 for now. // ExecutorAllocationManager will send the real initial limit to the Master later. @@ -103,8 +103,8 @@ private[spark] class StandaloneSchedulerBackend( } else { None } - val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, - appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit) + val appDesc = ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, + webUrl, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit) client = new StandaloneAppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() launcherBackend.setState(SparkAppHandle.State.SUBMITTED) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala index f6a9f9c5573d..76af33c1a18d 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala @@ -21,7 +21,7 @@ import java.lang.annotation.Annotation import java.lang.reflect.Type import java.nio.charset.StandardCharsets import java.text.SimpleDateFormat -import java.util.{Calendar, SimpleTimeZone} +import java.util.{Calendar, Locale, SimpleTimeZone} import javax.ws.rs.Produces import javax.ws.rs.core.{MediaType, MultivaluedMap} import javax.ws.rs.ext.{MessageBodyWriter, Provider} @@ -86,7 +86,7 @@ private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{ private[spark] object JacksonMessageWriter { def makeISODateFormat: SimpleDateFormat = { - val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'") + val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'", Locale.US) val cal = Calendar.getInstance(new SimpleTimeZone(0, "GMT")) iso8601.setCalendar(cal) iso8601 diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala index 0c71cd238222..d8d5e8958b23 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala +++ 
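// A sketch of the encapsulation change above (class and field names are illustrative): the
// collection is exposed as a read-only Seq backed by a private var, and only trusted code can
// swap in a new value, so callers can no longer mutate the old ListBuffer in place.
class TaskInfoSketch {
  private[this] var _accumulables: Seq[String] = Nil

  // Read-only view; replaces the previously exposed mutable ListBuffer.
  def accumulables: Seq[String] = _accumulables

  // Spark scopes the real setter private[spark]; plain private plus a companion demo here.
  private def setAccumulables(newAccumulables: Seq[String]): Unit =
    _accumulables = newAccumulables
}

object TaskInfoSketch {
  def main(args: Array[String]): Unit = {
    val info = new TaskInfoSketch
    info.setAccumulables("acc-1" +: info.accumulables)  // how DAGScheduler now prepends updates
    println(info.accumulables)
  }
}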
b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala @@ -17,7 +17,7 @@ package org.apache.spark.status.api.v1 import java.text.{ParseException, SimpleDateFormat} -import java.util.TimeZone +import java.util.{Locale, TimeZone} import javax.ws.rs.WebApplicationException import javax.ws.rs.core.Response import javax.ws.rs.core.Response.Status @@ -25,12 +25,12 @@ import javax.ws.rs.core.Response.Status private[v1] class SimpleDateParam(val originalValue: String) { val timestamp: Long = { - val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz") + val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz", Locale.US) try { format.parse(originalValue).getTime() } catch { case _: ParseException => - val gmtDay = new SimpleDateFormat("yyyy-MM-dd") + val gmtDay = new SimpleDateFormat("yyyy-MM-dd", Locale.US) gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) try { gmtDay.parse(originalValue).getTime() diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index f631a047a707..b828532aba7a 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -82,7 +82,7 @@ private[spark] class SparkUI private ( initialize() def getSparkUser: String = { - environmentListener.systemProperties.toMap.get("user.name").getOrElse("") + environmentListener.systemProperties.toMap.getOrElse("user.name", "") } def getAppName: String = appName @@ -94,16 +94,9 @@ private[spark] class SparkUI private ( /** Stop the server behind this web interface. Only valid after bind(). */ override def stop() { super.stop() - logInfo("Stopped Spark web UI at %s".format(appUIAddress)) + logInfo(s"Stopped Spark web UI at $webUrl") } - /** - * Return the application UI host:port. This does not include the scheme (http://). - */ - private[spark] def appUIHostPort = publicHostName + ":" + boundPort - - private[spark] def appUIAddress = s"http://$appUIHostPort" - def getSparkUI(appId: String): Option[SparkUI] = { if (appId == this.appId) Some(this) else None } @@ -136,7 +129,7 @@ private[spark] class SparkUI private ( private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) extends WebUITab(parent, prefix) { - def appName: String = parent.getAppName + def appName: String = parent.appName } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index c0d1a2220f62..57f6f2f0a9be 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -36,7 +36,8 @@ private[spark] object UIUtils extends Logging { // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. 
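// A small sketch of the per-thread formatter pattern the following UIUtils hunk relies on:
// SimpleDateFormat instances are not thread-safe, so each thread lazily gets its own formatter,
// again pinned to Locale.US so the output is stable across JVM default locales.
import java.text.SimpleDateFormat
import java.util.{Date, Locale}

object PerThreadDateFormat {
  private val fmt = new ThreadLocal[SimpleDateFormat]() {
    override def initialValue(): SimpleDateFormat =
      new SimpleDateFormat("yyyy/MM/dd HH:mm:ss", Locale.US)
  }

  def formatDate(date: Date): String = fmt.get.format(date)

  def main(args: Array[String]): Unit = println(formatDate(new Date()))
}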
private val dateFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + override def initialValue(): SimpleDateFormat = + new SimpleDateFormat("yyyy/MM/dd HH:mm:ss", Locale.US) } def formatDate(date: Date): String = dateFormat.get.format(date) @@ -170,6 +171,7 @@ private[spark] object UIUtils extends Logging { + } def vizHeaderNodes: Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index a05e0efb7a3e..8c801558672f 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -56,8 +56,8 @@ private[spark] abstract class WebUI( private val className = Utils.getFormattedClassName(this) def getBasePath: String = basePath - def getTabs: Seq[WebUITab] = tabs.toSeq - def getHandlers: Seq[ServletContextHandler] = handlers.toSeq + def getTabs: Seq[WebUITab] = tabs + def getHandlers: Seq[ServletContextHandler] = handlers def getSecurityManager: SecurityManager = securityManager /** Attach a tab to this UI, along with all of its attached pages. */ @@ -133,7 +133,7 @@ private[spark] abstract class WebUI( def initialize(): Unit /** Bind to the HTTP server behind this web interface. */ - def bind() { + def bind(): Unit = { assert(!serverInfo.isDefined, s"Attempted to bind $className more than once!") try { val host = Option(conf.getenv("SPARK_LOCAL_IP")).getOrElse("0.0.0.0") @@ -156,7 +156,7 @@ private[spark] abstract class WebUI( def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1) /** Stop the server behind this web interface. Only valid after bind(). */ - def stop() { + def stop(): Unit = { assert(serverInfo.isDefined, s"Attempted to stop $className before binding to a server!") serverInfo.get.stop() diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index a0ef80d9bdae..c6a07445f2a3 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -48,6 +48,16 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage } }.map { thread => val threadId = thread.threadId + val blockedBy = thread.blockedByThreadId match { + case Some(blockedByThreadId) => + + case None => Text("") + } + val heldLocks = thread.holdingLocks.mkString(", ") + {threadId} {thread.threadName} {thread.threadState} + {blockedBy}{heldLocks} {thread.stackTrace} } @@ -86,6 +97,7 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage Thread ID Thread Name Thread State + Thread Locks {dumpRows} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 173fc3cf31ce..50e8e2d19e15 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -289,8 +289,8 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { val startTime = listener.startTime val endTime = listener.endTime val activeJobs = listener.activeJobs.values.toSeq - val completedJobs = listener.completedJobs.reverse.toSeq - val failedJobs = listener.failedJobs.reverse.toSeq + val completedJobs = listener.completedJobs.reverse + val failedJobs = listener.failedJobs.reverse val 
activeJobsTable = jobsTable(request, "active", "activeJob", activeJobs, killEnabled = parent.killEnabled) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index f4a04609c4c6..9ce8542f0279 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import scala.collection.mutable.{HashMap, LinkedHashMap} import org.apache.spark.JobExecutionStatus -import org.apache.spark.executor.{ShuffleReadMetrics, ShuffleWriteMetrics, TaskMetrics} +import org.apache.spark.executor._ import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} import org.apache.spark.util.AccumulatorContext import org.apache.spark.util.collection.OpenHashSet @@ -147,9 +147,8 @@ private[spark] object UIData { memoryBytesSpilled = m.memoryBytesSpilled, diskBytesSpilled = m.diskBytesSpilled, peakExecutionMemory = m.peakExecutionMemory, - inputMetrics = InputMetricsUIData(m.inputMetrics.bytesRead, m.inputMetrics.recordsRead), - outputMetrics = - OutputMetricsUIData(m.outputMetrics.bytesWritten, m.outputMetrics.recordsWritten), + inputMetrics = InputMetricsUIData(m.inputMetrics), + outputMetrics = OutputMetricsUIData(m.outputMetrics), shuffleReadMetrics = ShuffleReadMetricsUIData(m.shuffleReadMetrics), shuffleWriteMetrics = ShuffleWriteMetricsUIData(m.shuffleWriteMetrics)) } @@ -171,9 +170,9 @@ private[spark] object UIData { speculative = taskInfo.speculative ) newTaskInfo.gettingResultTime = taskInfo.gettingResultTime - newTaskInfo.accumulables ++= taskInfo.accumulables.filter { + newTaskInfo.setAccumulables(taskInfo.accumulables.filter { accum => !accum.internal && accum.metadata != Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER) - } + }) newTaskInfo.finishTime = taskInfo.finishTime newTaskInfo.failed = taskInfo.failed newTaskInfo @@ -197,8 +196,32 @@ private[spark] object UIData { shuffleWriteMetrics: ShuffleWriteMetricsUIData) case class InputMetricsUIData(bytesRead: Long, recordsRead: Long) + object InputMetricsUIData { + def apply(metrics: InputMetrics): InputMetricsUIData = { + if (metrics.bytesRead == 0 && metrics.recordsRead == 0) { + EMPTY + } else { + new InputMetricsUIData( + bytesRead = metrics.bytesRead, + recordsRead = metrics.recordsRead) + } + } + private val EMPTY = InputMetricsUIData(0, 0) + } case class OutputMetricsUIData(bytesWritten: Long, recordsWritten: Long) + object OutputMetricsUIData { + def apply(metrics: OutputMetrics): OutputMetricsUIData = { + if (metrics.bytesWritten == 0 && metrics.recordsWritten == 0) { + EMPTY + } else { + new OutputMetricsUIData( + bytesWritten = metrics.bytesWritten, + recordsWritten = metrics.recordsWritten) + } + } + private val EMPTY = OutputMetricsUIData(0, 0) + } case class ShuffleReadMetricsUIData( remoteBlocksFetched: Long, @@ -212,17 +235,30 @@ private[spark] object UIData { object ShuffleReadMetricsUIData { def apply(metrics: ShuffleReadMetrics): ShuffleReadMetricsUIData = { - new ShuffleReadMetricsUIData( - remoteBlocksFetched = metrics.remoteBlocksFetched, - localBlocksFetched = metrics.localBlocksFetched, - remoteBytesRead = metrics.remoteBytesRead, - localBytesRead = metrics.localBytesRead, - fetchWaitTime = metrics.fetchWaitTime, - recordsRead = metrics.recordsRead, - totalBytesRead = metrics.totalBytesRead, - totalBlocksFetched = metrics.totalBlocksFetched - ) + if ( + metrics.remoteBlocksFetched == 0 && + metrics.localBlocksFetched == 0 && + 
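// A compact sketch of the memoization pattern the UIData changes above apply to every metrics
// case class (names here are illustrative): when all fields are zero, which is the common case,
// a single shared EMPTY instance is returned instead of allocating a fresh object per task,
// shrinking the listener's retained UI data on large stages.
case class OutputMetricsSketch(bytesWritten: Long, recordsWritten: Long)

object OutputMetricsSketch {
  private val EMPTY = new OutputMetricsSketch(0, 0)

  def of(bytesWritten: Long, recordsWritten: Long): OutputMetricsSketch =
    if (bytesWritten == 0 && recordsWritten == 0) EMPTY
    else new OutputMetricsSketch(bytesWritten, recordsWritten)
}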
metrics.remoteBytesRead == 0 && + metrics.localBytesRead == 0 && + metrics.fetchWaitTime == 0 && + metrics.recordsRead == 0 && + metrics.totalBytesRead == 0 && + metrics.totalBlocksFetched == 0) { + EMPTY + } else { + new ShuffleReadMetricsUIData( + remoteBlocksFetched = metrics.remoteBlocksFetched, + localBlocksFetched = metrics.localBlocksFetched, + remoteBytesRead = metrics.remoteBytesRead, + localBytesRead = metrics.localBytesRead, + fetchWaitTime = metrics.fetchWaitTime, + recordsRead = metrics.recordsRead, + totalBytesRead = metrics.totalBytesRead, + totalBlocksFetched = metrics.totalBlocksFetched + ) + } } + private val EMPTY = ShuffleReadMetricsUIData(0, 0, 0, 0, 0, 0, 0, 0) } case class ShuffleWriteMetricsUIData( @@ -232,12 +268,17 @@ private[spark] object UIData { object ShuffleWriteMetricsUIData { def apply(metrics: ShuffleWriteMetrics): ShuffleWriteMetricsUIData = { - new ShuffleWriteMetricsUIData( - bytesWritten = metrics.bytesWritten, - recordsWritten = metrics.recordsWritten, - writeTime = metrics.writeTime - ) + if (metrics.bytesWritten == 0 && metrics.recordsWritten == 0 && metrics.writeTime == 0) { + EMPTY + } else { + new ShuffleWriteMetricsUIData( + bytesWritten = metrics.bytesWritten, + recordsWritten = metrics.recordsWritten, + writeTime = metrics.writeTime + ) + } } + private val EMPTY = ShuffleWriteMetricsUIData(0, 0, 0) } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index c11eb3ffa460..4b4d2d10cbf8 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -107,20 +107,20 @@ private[spark] object JsonProtocol { def stageSubmittedToJson(stageSubmitted: SparkListenerStageSubmitted): JValue = { val stageInfo = stageInfoToJson(stageSubmitted.stageInfo) val properties = propertiesToJson(stageSubmitted.properties) - ("Event" -> Utils.getFormattedClassName(stageSubmitted)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.stageSubmitted) ~ ("Stage Info" -> stageInfo) ~ ("Properties" -> properties) } def stageCompletedToJson(stageCompleted: SparkListenerStageCompleted): JValue = { val stageInfo = stageInfoToJson(stageCompleted.stageInfo) - ("Event" -> Utils.getFormattedClassName(stageCompleted)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.stageCompleted) ~ ("Stage Info" -> stageInfo) } def taskStartToJson(taskStart: SparkListenerTaskStart): JValue = { val taskInfo = taskStart.taskInfo - ("Event" -> Utils.getFormattedClassName(taskStart)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskStart) ~ ("Stage ID" -> taskStart.stageId) ~ ("Stage Attempt ID" -> taskStart.stageAttemptId) ~ ("Task Info" -> taskInfoToJson(taskInfo)) @@ -128,7 +128,7 @@ private[spark] object JsonProtocol { def taskGettingResultToJson(taskGettingResult: SparkListenerTaskGettingResult): JValue = { val taskInfo = taskGettingResult.taskInfo - ("Event" -> Utils.getFormattedClassName(taskGettingResult)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskGettingResult) ~ ("Task Info" -> taskInfoToJson(taskInfo)) } @@ -137,7 +137,7 @@ private[spark] object JsonProtocol { val taskInfo = taskEnd.taskInfo val taskMetrics = taskEnd.taskMetrics val taskMetricsJson = if (taskMetrics != null) taskMetricsToJson(taskMetrics) else JNothing - ("Event" -> Utils.getFormattedClassName(taskEnd)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskEnd) ~ ("Stage ID" -> taskEnd.stageId) ~ 
("Stage Attempt ID" -> taskEnd.stageAttemptId) ~ ("Task Type" -> taskEnd.taskType) ~ @@ -148,7 +148,7 @@ private[spark] object JsonProtocol { def jobStartToJson(jobStart: SparkListenerJobStart): JValue = { val properties = propertiesToJson(jobStart.properties) - ("Event" -> Utils.getFormattedClassName(jobStart)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.jobStart) ~ ("Job ID" -> jobStart.jobId) ~ ("Submission Time" -> jobStart.time) ~ ("Stage Infos" -> jobStart.stageInfos.map(stageInfoToJson)) ~ // Added in Spark 1.2.0 @@ -158,7 +158,7 @@ private[spark] object JsonProtocol { def jobEndToJson(jobEnd: SparkListenerJobEnd): JValue = { val jobResult = jobResultToJson(jobEnd.jobResult) - ("Event" -> Utils.getFormattedClassName(jobEnd)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.jobEnd) ~ ("Job ID" -> jobEnd.jobId) ~ ("Completion Time" -> jobEnd.time) ~ ("Job Result" -> jobResult) @@ -170,7 +170,7 @@ private[spark] object JsonProtocol { val sparkProperties = mapToJson(environmentDetails("Spark Properties").toMap) val systemProperties = mapToJson(environmentDetails("System Properties").toMap) val classpathEntries = mapToJson(environmentDetails("Classpath Entries").toMap) - ("Event" -> Utils.getFormattedClassName(environmentUpdate)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.environmentUpdate) ~ ("JVM Information" -> jvmInformation) ~ ("Spark Properties" -> sparkProperties) ~ ("System Properties" -> systemProperties) ~ @@ -179,7 +179,7 @@ private[spark] object JsonProtocol { def blockManagerAddedToJson(blockManagerAdded: SparkListenerBlockManagerAdded): JValue = { val blockManagerId = blockManagerIdToJson(blockManagerAdded.blockManagerId) - ("Event" -> Utils.getFormattedClassName(blockManagerAdded)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.blockManagerAdded) ~ ("Block Manager ID" -> blockManagerId) ~ ("Maximum Memory" -> blockManagerAdded.maxMem) ~ ("Timestamp" -> blockManagerAdded.time) @@ -187,18 +187,18 @@ private[spark] object JsonProtocol { def blockManagerRemovedToJson(blockManagerRemoved: SparkListenerBlockManagerRemoved): JValue = { val blockManagerId = blockManagerIdToJson(blockManagerRemoved.blockManagerId) - ("Event" -> Utils.getFormattedClassName(blockManagerRemoved)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.blockManagerRemoved) ~ ("Block Manager ID" -> blockManagerId) ~ ("Timestamp" -> blockManagerRemoved.time) } def unpersistRDDToJson(unpersistRDD: SparkListenerUnpersistRDD): JValue = { - ("Event" -> Utils.getFormattedClassName(unpersistRDD)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.unpersistRDD) ~ ("RDD ID" -> unpersistRDD.rddId) } def applicationStartToJson(applicationStart: SparkListenerApplicationStart): JValue = { - ("Event" -> Utils.getFormattedClassName(applicationStart)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.applicationStart) ~ ("App Name" -> applicationStart.appName) ~ ("App ID" -> applicationStart.appId.map(JString(_)).getOrElse(JNothing)) ~ ("Timestamp" -> applicationStart.time) ~ @@ -208,33 +208,33 @@ private[spark] object JsonProtocol { } def applicationEndToJson(applicationEnd: SparkListenerApplicationEnd): JValue = { - ("Event" -> Utils.getFormattedClassName(applicationEnd)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.applicationEnd) ~ ("Timestamp" -> applicationEnd.time) } def executorAddedToJson(executorAdded: SparkListenerExecutorAdded): JValue = { - ("Event" -> Utils.getFormattedClassName(executorAdded)) ~ + ("Event" -> 
SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.executorAdded) ~ ("Timestamp" -> executorAdded.time) ~ ("Executor ID" -> executorAdded.executorId) ~ ("Executor Info" -> executorInfoToJson(executorAdded.executorInfo)) } def executorRemovedToJson(executorRemoved: SparkListenerExecutorRemoved): JValue = { - ("Event" -> Utils.getFormattedClassName(executorRemoved)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.executorRemoved) ~ ("Timestamp" -> executorRemoved.time) ~ ("Executor ID" -> executorRemoved.executorId) ~ ("Removed Reason" -> executorRemoved.reason) } def logStartToJson(logStart: SparkListenerLogStart): JValue = { - ("Event" -> Utils.getFormattedClassName(logStart)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.logStart) ~ ("Spark Version" -> SPARK_VERSION) } def executorMetricsUpdateToJson(metricsUpdate: SparkListenerExecutorMetricsUpdate): JValue = { val execId = metricsUpdate.execId val accumUpdates = metricsUpdate.accumUpdates - ("Event" -> Utils.getFormattedClassName(metricsUpdate)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.metricsUpdate) ~ ("Executor ID" -> execId) ~ ("Metrics Updated" -> accumUpdates.map { case (taskId, stageId, stageAttemptId, updates) => ("Task ID" -> taskId) ~ @@ -485,7 +485,7 @@ private[spark] object JsonProtocol { * JSON deserialization methods for SparkListenerEvents | * ---------------------------------------------------- */ - def sparkEventFromJson(json: JValue): SparkListenerEvent = { + private object SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES { val stageSubmitted = Utils.getFormattedClassName(SparkListenerStageSubmitted) val stageCompleted = Utils.getFormattedClassName(SparkListenerStageCompleted) val taskStart = Utils.getFormattedClassName(SparkListenerTaskStart) @@ -503,6 +503,10 @@ private[spark] object JsonProtocol { val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved) val logStart = Utils.getFormattedClassName(SparkListenerLogStart) val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate) + } + + def sparkEventFromJson(json: JValue): SparkListenerEvent = { + import SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES._ (json \ "Event").extract[String] match { case `stageSubmitted` => stageSubmittedFromJson(json) @@ -540,7 +544,8 @@ private[spark] object JsonProtocol { def taskStartFromJson(json: JValue): SparkListenerTaskStart = { val stageId = (json \ "Stage ID").extract[Int] - val stageAttemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0) + val stageAttemptId = + Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val taskInfo = taskInfoFromJson(json \ "Task Info") SparkListenerTaskStart(stageId, stageAttemptId, taskInfo) } @@ -552,7 +557,8 @@ private[spark] object JsonProtocol { def taskEndFromJson(json: JValue): SparkListenerTaskEnd = { val stageId = (json \ "Stage ID").extract[Int] - val stageAttemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0) + val stageAttemptId = + Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val taskType = (json \ "Task Type").extract[String] val taskEndReason = taskEndReasonFromJson(json \ "Task End Reason") val taskInfo = taskInfoFromJson(json \ "Task Info") @@ -662,20 +668,22 @@ private[spark] object JsonProtocol { def stageInfoFromJson(json: JValue): StageInfo = { val stageId = (json \ "Stage ID").extract[Int] - val attemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0) + val attemptId = Utils.jsonOption(json \ "Stage Attempt 
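// An illustrative sketch of the dispatch pattern above: the formatted event names are computed
// once in a private object, and the backticked names in the match are stable identifiers, so
// each case compares against the precomputed constant instead of binding a fresh variable or
// recomputing the class name per event.
object EventDispatchSketch {
  private object Names {
    // Spark derives these with Utils.getFormattedClassName; plain literals stand in here.
    val taskStart = "SparkListenerTaskStart"
    val taskEnd = "SparkListenerTaskEnd"
  }

  def describe(event: String): String = {
    import Names._
    event match {
      case `taskStart` => "a task started"
      case `taskEnd` => "a task finished"
      case other => s"unhandled event: $other"
    }
  }

  def main(args: Array[String]): Unit =
    println(describe("SparkListenerTaskEnd"))
}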
ID").map(_.extract[Int]).getOrElse(0) val stageName = (json \ "Stage Name").extract[String] val numTasks = (json \ "Number of Tasks").extract[Int] val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson) val parentIds = Utils.jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) - val details = (json \ "Details").extractOpt[String].getOrElse("") + val details = Utils.jsonOption(json \ "Details").map(_.extract[String]).getOrElse("") val submissionTime = Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]) val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]) val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String]) - val accumulatedValues = (json \ "Accumulables").extractOpt[List[JValue]] match { - case Some(values) => values.map(accumulableInfoFromJson) - case None => Seq[AccumulableInfo]() + val accumulatedValues = { + Utils.jsonOption(json \ "Accumulables").map(_.extract[List[JValue]]) match { + case Some(values) => values.map(accumulableInfoFromJson) + case None => Seq[AccumulableInfo]() + } } val stageInfo = new StageInfo( @@ -692,17 +700,17 @@ private[spark] object JsonProtocol { def taskInfoFromJson(json: JValue): TaskInfo = { val taskId = (json \ "Task ID").extract[Long] val index = (json \ "Index").extract[Int] - val attempt = (json \ "Attempt").extractOpt[Int].getOrElse(1) + val attempt = Utils.jsonOption(json \ "Attempt").map(_.extract[Int]).getOrElse(1) val launchTime = (json \ "Launch Time").extract[Long] - val executorId = (json \ "Executor ID").extract[String] - val host = (json \ "Host").extract[String] + val executorId = (json \ "Executor ID").extract[String].intern() + val host = (json \ "Host").extract[String].intern() val taskLocality = TaskLocality.withName((json \ "Locality").extract[String]) - val speculative = (json \ "Speculative").extractOpt[Boolean].getOrElse(false) + val speculative = Utils.jsonOption(json \ "Speculative").exists(_.extract[Boolean]) val gettingResultTime = (json \ "Getting Result Time").extract[Long] val finishTime = (json \ "Finish Time").extract[Long] val failed = (json \ "Failed").extract[Boolean] - val killed = (json \ "Killed").extractOpt[Boolean].getOrElse(false) - val accumulables = (json \ "Accumulables").extractOpt[Seq[JValue]] match { + val killed = Utils.jsonOption(json \ "Killed").exists(_.extract[Boolean]) + val accumulables = Utils.jsonOption(json \ "Accumulables").map(_.extract[Seq[JValue]]) match { case Some(values) => values.map(accumulableInfoFromJson) case None => Seq[AccumulableInfo]() } @@ -713,18 +721,19 @@ private[spark] object JsonProtocol { taskInfo.finishTime = finishTime taskInfo.failed = failed taskInfo.killed = killed - accumulables.foreach { taskInfo.accumulables += _ } + taskInfo.setAccumulables(accumulables) taskInfo } def accumulableInfoFromJson(json: JValue): AccumulableInfo = { val id = (json \ "ID").extract[Long] - val name = (json \ "Name").extractOpt[String] + val name = Utils.jsonOption(json \ "Name").map(_.extract[String]) val update = Utils.jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) } val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) } - val internal = (json \ "Internal").extractOpt[Boolean].getOrElse(false) - val countFailedValues = (json \ "Count Failed Values").extractOpt[Boolean].getOrElse(false) - val metadata = (json \ "Metadata").extractOpt[String] + val internal = Utils.jsonOption(json \ 
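// A small sketch of the interning above: a replayed history log repeats the same executor ID
// and host strings across many thousands of TaskInfo objects, and String.intern() makes every
// occurrence share one canonical instance instead of retaining a copy per task.
object InternSketch {
  def main(args: Array[String]): Unit = {
    val first = new String("host-17.example.com").intern()
    val second = new String("host-17.example.com").intern()
    assert(first eq second)   // same canonical object, so replay keeps one copy, not thousands
    println(first)
  }
}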
"Internal").exists(_.extract[Boolean]) + val countFailedValues = + Utils.jsonOption(json \ "Count Failed Values").exists(_.extract[Boolean]) + val metadata = Utils.jsonOption(json \ "Metadata").map(_.extract[String]) new AccumulableInfo(id, name, update, value, internal, countFailedValues, metadata) } @@ -782,9 +791,11 @@ private[spark] object JsonProtocol { readMetrics.incRemoteBlocksFetched((readJson \ "Remote Blocks Fetched").extract[Int]) readMetrics.incLocalBlocksFetched((readJson \ "Local Blocks Fetched").extract[Int]) readMetrics.incRemoteBytesRead((readJson \ "Remote Bytes Read").extract[Long]) - readMetrics.incLocalBytesRead((readJson \ "Local Bytes Read").extractOpt[Long].getOrElse(0L)) + readMetrics.incLocalBytesRead( + Utils.jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L)) readMetrics.incFetchWaitTime((readJson \ "Fetch Wait Time").extract[Long]) - readMetrics.incRecordsRead((readJson \ "Total Records Read").extractOpt[Long].getOrElse(0L)) + readMetrics.incRecordsRead( + Utils.jsonOption(readJson \ "Total Records Read").map(_.extract[Long]).getOrElse(0L)) metrics.mergeShuffleReadMetrics() } @@ -793,8 +804,8 @@ private[spark] object JsonProtocol { Utils.jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson => val writeMetrics = metrics.shuffleWriteMetrics writeMetrics.incBytesWritten((writeJson \ "Shuffle Bytes Written").extract[Long]) - writeMetrics.incRecordsWritten((writeJson \ "Shuffle Records Written") - .extractOpt[Long].getOrElse(0L)) + writeMetrics.incRecordsWritten( + Utils.jsonOption(writeJson \ "Shuffle Records Written").map(_.extract[Long]).getOrElse(0L)) writeMetrics.incWriteTime((writeJson \ "Shuffle Write Time").extract[Long]) } @@ -802,14 +813,16 @@ private[spark] object JsonProtocol { Utils.jsonOption(json \ "Output Metrics").foreach { outJson => val outputMetrics = metrics.outputMetrics outputMetrics.setBytesWritten((outJson \ "Bytes Written").extract[Long]) - outputMetrics.setRecordsWritten((outJson \ "Records Written").extractOpt[Long].getOrElse(0L)) + outputMetrics.setRecordsWritten( + Utils.jsonOption(outJson \ "Records Written").map(_.extract[Long]).getOrElse(0L)) } // Input metrics Utils.jsonOption(json \ "Input Metrics").foreach { inJson => val inputMetrics = metrics.inputMetrics inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long]) - inputMetrics.incRecordsRead((inJson \ "Records Read").extractOpt[Long].getOrElse(0L)) + inputMetrics.incRecordsRead( + Utils.jsonOption(inJson \ "Records Read").map(_.extract[Long]).getOrElse(0L)) } // Updated blocks @@ -824,7 +837,7 @@ private[spark] object JsonProtocol { metrics } - def taskEndReasonFromJson(json: JValue): TaskEndReason = { + private object TASK_END_REASON_FORMATTED_CLASS_NAMES { val success = Utils.getFormattedClassName(Success) val resubmitted = Utils.getFormattedClassName(Resubmitted) val fetchFailed = Utils.getFormattedClassName(FetchFailed) @@ -834,6 +847,10 @@ private[spark] object JsonProtocol { val taskCommitDenied = Utils.getFormattedClassName(TaskCommitDenied) val executorLostFailure = Utils.getFormattedClassName(ExecutorLostFailure) val unknownReason = Utils.getFormattedClassName(UnknownReason) + } + + def taskEndReasonFromJson(json: JValue): TaskEndReason = { + import TASK_END_REASON_FORMATTED_CLASS_NAMES._ (json \ "Reason").extract[String] match { case `success` => Success @@ -850,7 +867,8 @@ private[spark] object JsonProtocol { val className = (json \ "Class Name").extract[String] val description = (json \ "Description").extract[String] val 
stackTrace = stackTraceFromJson(json \ "Stack Trace") - val fullStackTrace = (json \ "Full Stack Trace").extractOpt[String].orNull + val fullStackTrace = + Utils.jsonOption(json \ "Full Stack Trace").map(_.extract[String]).orNull // Fallback on getting accumulator updates from TaskMetrics, which was logged in Spark 1.x val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates") .map(_.extract[List[JValue]].map(accumulableInfoFromJson)) @@ -885,15 +903,19 @@ private[spark] object JsonProtocol { if (json == JNothing) { return null } - val executorId = (json \ "Executor ID").extract[String] - val host = (json \ "Host").extract[String] + val executorId = (json \ "Executor ID").extract[String].intern() + val host = (json \ "Host").extract[String].intern() val port = (json \ "Port").extract[Int] BlockManagerId(executorId, host, port) } - def jobResultFromJson(json: JValue): JobResult = { + private object JOB_RESULT_FORMATTED_CLASS_NAMES { val jobSucceeded = Utils.getFormattedClassName(JobSucceeded) val jobFailed = Utils.getFormattedClassName(JobFailed) + } + + def jobResultFromJson(json: JValue): JobResult = { + import JOB_RESULT_FORMATTED_CLASS_NAMES._ (json \ "Result").extract[String] match { case `jobSucceeded` => JobSucceeded diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala index d4e0ad93b966..b1217980faf1 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala @@ -24,4 +24,8 @@ private[spark] case class ThreadStackTrace( threadId: Long, threadName: String, threadState: Thread.State, - stackTrace: String) + stackTrace: String, + blockedByThreadId: Option[Long], + blockedByLock: String, + holdingLocks: Seq[String]) + diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 6027b07c0fee..1de66af632a8 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,7 +18,7 @@ package org.apache.spark.util import java.io._ -import java.lang.management.ManagementFactory +import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo} import java.net._ import java.nio.ByteBuffer import java.nio.channels.Channels @@ -2096,15 +2096,41 @@ private[spark] object Utils extends Logging { } } + private implicit class Lock(lock: LockInfo) { + def lockString: String = { + lock match { + case monitor: MonitorInfo => + s"Monitor(${lock.getClassName}@${lock.getIdentityHashCode}})" + case _ => + s"Lock(${lock.getClassName}@${lock.getIdentityHashCode}})" + } + } + } + /** Return a thread dump of all threads' stacktraces. Used to capture dumps for the web UI */ def getThreadDump(): Array[ThreadStackTrace] = { // We need to filter out null values here because dumpAllThreads() may return null array // elements for threads that are dead / don't exist. 
val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null) threadInfos.sortBy(_.getThreadId).map { case threadInfo => - val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n") - ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, - threadInfo.getThreadState, stackTrace) + val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap + val stackTrace = threadInfo.getStackTrace.map { frame => + monitors.get(frame) match { + case Some(monitor) => + monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}" + case None => + frame.toString + } + }.mkString("\n") + + // use a set to dedup re-entrant locks that are held at multiple places + val heldLocks = (threadInfo.getLockedSynchronizers.map(_.lockString) + ++ threadInfo.getLockedMonitors.map(_.lockString) + ).toSet + + ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, threadInfo.getThreadState, + stackTrace, if (threadInfo.getLockOwnerId < 0) None else Some(threadInfo.getLockOwnerId), + Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""), heldLocks.toSeq) } } @@ -2513,6 +2539,8 @@ private[util] object CallerContext extends Logging { val callerContextSupported: Boolean = { SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && { try { + // `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in + // master Maven build, so do not use it before resolving SPARK-17714. // scalastyle:off classforname Class.forName("org.apache.hadoop.ipc.CallerContext") Class.forName("org.apache.hadoop.ipc.CallerContext$Builder") @@ -2578,6 +2606,8 @@ private[spark] class CallerContext( def setCurrentContext(): Unit = { if (CallerContext.callerContextSupported) { try { + // `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in + // master Maven build, so do not use it before resolving SPARK-17714. 
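// A hedged, standalone sketch of the thread-dump enrichment above: ThreadMXBean can report the
// monitors and ownable synchronizers each thread holds, plus the lock it is blocked on, which
// is the information that feeds the new blocked-by and held-locks columns.
import java.lang.management.ManagementFactory

object LockDumpSketch {
  def main(args: Array[String]): Unit = {
    val threadInfos = ManagementFactory.getThreadMXBean
      .dumpAllThreads(true, true)   // lockedMonitors = true, lockedSynchronizers = true
      .filter(_ != null)

    threadInfos.sortBy(_.getThreadId).foreach { info =>
      // A set dedups re-entrant locks that are reported once per acquisition site.
      val heldLocks = (info.getLockedMonitors.map(_.toString) ++
        info.getLockedSynchronizers.map(_.toString)).toSet
      val blockedOn = Option(info.getLockInfo).map(_.toString).getOrElse("-")
      println(s"${info.getThreadId} ${info.getThreadName}: blocked on $blockedOn; " +
        s"holding ${heldLocks.mkString(", ")}")
    }
  }
}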
// scalastyle:off classforname val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext") val builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder") diff --git a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala index 6b74a29aceda..bcb95b416dd2 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala @@ -140,16 +140,16 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) var i = 1 while (true) { val curKey = data(2 * pos) - if (k.eq(curKey) || k.equals(curKey)) { - val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V]) - data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] - return newValue - } else if (curKey.eq(null)) { + if (curKey.eq(null)) { val newValue = updateFunc(false, null.asInstanceOf[V]) data(2 * pos) = k data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] incrementSize() return newValue + } else if (k.eq(curKey) || k.equals(curKey)) { + val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V]) + data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] + return newValue } else { val delta = i pos = (pos + delta) & mask diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 0f6a425e3db9..60f6f537c1d5 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -48,7 +48,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( require(initialCapacity <= OpenHashSet.MAX_CAPACITY, s"Can't make capacity bigger than ${OpenHashSet.MAX_CAPACITY} elements") - require(initialCapacity >= 1, "Invalid initial capacity") + require(initialCapacity >= 0, "Invalid initial capacity") require(loadFactor < 1.0, "Load factor must be less than 1.0") require(loadFactor > 0.0, "Load factor must be greater than 0.0") @@ -271,8 +271,12 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( private def hashcode(h: Int): Int = Hashing.murmur3_32().hashInt(h).asInt() private def nextPowerOf2(n: Int): Int = { - val highBit = Integer.highestOneBit(n) - if (highBit == n) n else highBit << 1 + if (n == 0) { + 1 + } else { + val highBit = Integer.highestOneBit(n) + if (highBit == n) n else highBit << 1 + } } } diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala index 5c4238c0381a..1f263df57c85 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala @@ -18,7 +18,7 @@ package org.apache.spark.util.logging import java.text.SimpleDateFormat -import java.util.Calendar +import java.util.{Calendar, Locale} import org.apache.spark.internal.Logging @@ -59,7 +59,7 @@ private[spark] class TimeBasedRollingPolicy( } @volatile private var nextRolloverTime = calculateNextRolloverTime() - private val formatter = new SimpleDateFormat(rollingFileSuffixPattern) + private val formatter = new SimpleDateFormat(rollingFileSuffixPattern, Locale.US) /** Should rollover if current time has exceeded next rollover time */ def shouldRollover(bytesToBeWritten: Long): Boolean = { @@ -109,7 +109,7 @@ private[spark] class SizeBasedRollingPolicy( } @volatile private var 
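// A tiny sketch of the OpenHashSet fix above: Integer.highestOneBit(0) is 0, so without the
// explicit zero case a set created with initialCapacity = 0 (now allowed) would compute a
// capacity of 0 instead of rounding up to 1.
object NextPowerOf2Sketch {
  def nextPowerOf2(n: Int): Int =
    if (n == 0) {
      1
    } else {
      val highBit = Integer.highestOneBit(n)
      if (highBit == n) n else highBit << 1
    }

  def main(args: Array[String]): Unit = {
    assert(nextPowerOf2(0) == 1)
    assert(nextPowerOf2(3) == 4)
    assert(nextPowerOf2(64) == 64)
    println("nextPowerOf2 behaves as expected")
  }
}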
bytesWrittenSinceRollover = 0L - val formatter = new SimpleDateFormat("--yyyy-MM-dd--HH-mm-ss--SSSS") + val formatter = new SimpleDateFormat("--yyyy-MM-dd--HH-mm-ss--SSSS", Locale.US) /** Should rollover if the next set of bytes is going to exceed the size limit */ def shouldRollover(bytesToBeWritten: Long): Boolean = { diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 6724af952505..0f78871ed35a 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -44,7 +44,7 @@ abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[So { implicit val defaultTimeout = timeout(10000 millis) val conf = new SparkConf() - .setMaster("local[2]") + .setMaster("local[4]") .setAppName("ContextCleanerSuite") .set("spark.cleaner.referenceTracking.blocking", "true") .set("spark.cleaner.referenceTracking.blocking.shuffle", "true") @@ -232,7 +232,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { // Verify that checkpoints are NOT cleaned up if the config is not enabled sc.stop() val conf = new SparkConf() - .setMaster("local[2]") + .setMaster("local[4]") .setAppName("cleanupCheckpoint") .set("spark.cleaner.referenceTracking.cleanCheckpoints", "false") sc = new SparkContext(conf) diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index cc52bb1d23cd..89f0b1cb5b56 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -58,10 +58,15 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { nums.saveAsTextFile(outputDir) // Read the plain text file and check it's OK val outputFile = new File(outputDir, "part-00000") - val content = Source.fromFile(outputFile).mkString - assert(content === "1\n2\n3\n4\n") - // Also try reading it in as a text file RDD - assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4")) + val bufferSrc = Source.fromFile(outputFile) + Utils.tryWithSafeFinally { + val content = bufferSrc.mkString + assert(content === "1\n2\n3\n4\n") + // Also try reading it in as a text file RDD + assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4")) + } { + bufferSrc.close() + } } test("text files (compressed)") { diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 915d7a1b8b16..5457a066d3c0 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -67,7 +67,7 @@ class HeartbeatReceiverSuite override def beforeEach(): Unit = { super.beforeEach() val conf = new SparkConf() - .setMaster("local[2]") + .setMaster("local[4]") .setAppName("test") .set("spark.dynamicAllocation.testing", "true") sc = spy(new SparkContext(conf)) diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index a3490fc79e45..5b89eaae032a 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -47,7 +47,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft test("local mode, FIFO scheduler") { val conf = new 
SparkConf().set("spark.scheduler.mode", "FIFO") - sc = new SparkContext("local[2]", "test", conf) + sc = new SparkContext("local[4]", "test", conf) testCount() testTake() // Make sure we can still launch tasks. @@ -58,7 +58,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft val conf = new SparkConf().set("spark.scheduler.mode", "FAIR") val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() conf.set("spark.scheduler.allocation.file", xmlPath) - sc = new SparkContext("local[2]", "test", conf) + sc = new SparkContext("local[4]", "test", conf) testCount() testTake() // Make sure we can still launch tasks. @@ -115,7 +115,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft } test("job group") { - sc = new SparkContext("local[2]", "test") + sc = new SparkContext("local[4]", "test") // Add a listener to release the semaphore once any tasks are launched. val sem = new Semaphore(0) @@ -145,7 +145,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft } test("inherited job group (SPARK-6629)") { - sc = new SparkContext("local[2]", "test") + sc = new SparkContext("local[4]", "test") // Add a listener to release the semaphore once any tasks are launched. val sem = new Semaphore(0) @@ -180,7 +180,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft } test("job group with interruption") { - sc = new SparkContext("local[2]", "test") + sc = new SparkContext("local[4]", "test") // Add a listener to release the semaphore once any tasks are launched. val sem = new Semaphore(0) @@ -215,7 +215,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft // make sure the first stage is not finished until cancel is issued val sem1 = new Semaphore(0) - sc = new SparkContext("local[2]", "test") + sc = new SparkContext("local[4]", "test") sc.addSparkListener(new SparkListener { override def onTaskStart(taskStart: SparkListenerTaskStart) { sem1.release() diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 83906cff123b..21b2726d7e1d 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -132,8 +132,8 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst test("SparkContext property overriding") { val conf = new SparkConf(false).setMaster("local").setAppName("My app") - sc = new SparkContext("local[2]", "My other app", conf) - assert(sc.master === "local[2]") + sc = new SparkContext("local[4]", "My other app", conf) + assert(sc.master === "local[4]") assert(sc.appName === "My other app") } diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala index 13cba94578a6..005587051b6a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -33,7 +33,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate -import org.apache.spark.util.ResetSystemProperties +import org.apache.spark.util.{ResetSystemProperties, Utils} class RPackageUtilsSuite extends SparkFunSuite @@ -74,9 +74,13 @@ class RPackageUtilsSuite val deps = Seq(dep1, 
dep2).mkString(",") IvyTestUtils.withRepository(main, Some(deps), None, withR = true) { repo => val jars = Seq(main, dep1, dep2).map(c => new JarFile(getJarPath(c, new File(new URI(repo))))) - assert(RPackageUtils.checkManifestForR(jars(0)), "should have R code") - assert(!RPackageUtils.checkManifestForR(jars(1)), "should not have R code") - assert(!RPackageUtils.checkManifestForR(jars(2)), "should not have R code") + Utils.tryWithSafeFinally { + assert(RPackageUtils.checkManifestForR(jars(0)), "should have R code") + assert(!RPackageUtils.checkManifestForR(jars(1)), "should not have R code") + assert(!RPackageUtils.checkManifestForR(jars(2)), "should not have R code") + } { + jars.foreach(_.close()) + } } } @@ -131,7 +135,7 @@ class RPackageUtilsSuite test("SparkR zipping works properly") { val tempDir = Files.createTempDir() - try { + Utils.tryWithSafeFinally { IvyTestUtils.writeFile(tempDir, "test.R", "abc") val fakeSparkRDir = new File(tempDir, "SparkR") assert(fakeSparkRDir.mkdirs()) @@ -144,14 +148,19 @@ class RPackageUtilsSuite IvyTestUtils.writeFile(fakePackageDir, "DESCRIPTION", "abc") val finalZip = RPackageUtils.zipRLibraries(tempDir, "sparkr.zip") assert(finalZip.exists()) - val entries = new ZipFile(finalZip).entries().asScala.map(_.getName).toSeq - assert(entries.contains("/test.R")) - assert(entries.contains("/SparkR/abc.R")) - assert(entries.contains("/SparkR/DESCRIPTION")) - assert(!entries.contains("/package.zip")) - assert(entries.contains("/packageTest/def.R")) - assert(entries.contains("/packageTest/DESCRIPTION")) - } finally { + val zipFile = new ZipFile(finalZip) + Utils.tryWithSafeFinally { + val entries = zipFile.entries().asScala.map(_.getName).toSeq + assert(entries.contains("/test.R")) + assert(entries.contains("/SparkR/abc.R")) + assert(entries.contains("/SparkR/DESCRIPTION")) + assert(!entries.contains("/package.zip")) + assert(entries.contains("/packageTest/def.R")) + assert(entries.contains("/packageTest/DESCRIPTION")) + } { + zipFile.close() + } + } { FileUtils.deleteDirectory(tempDir) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index a5eda7b5a5a7..2c41c432d1fe 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -449,8 +449,14 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val cstream = codec.map(_.compressedOutputStream(fstream)).getOrElse(fstream) val bstream = new BufferedOutputStream(cstream) if (isNewFormat) { - EventLoggingListener.initEventLog(new FileOutputStream(file)) + val newFormatStream = new FileOutputStream(file) + Utils.tryWithSafeFinally { + EventLoggingListener.initEventLog(newFormatStream) + } { + newFormatStream.close() + } } + val writer = new OutputStreamWriter(bstream, StandardCharsets.UTF_8) Utils.tryWithSafeFinally { events.foreach(e => writer.write(compact(render(JsonProtocol.sparkEventToJson(e))) + "\n")) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index a595bc174a31..715811a46f42 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -29,6 +29,8 @@ import com.codahale.metrics.Counter 
import com.google.common.io.{ByteStreams, Files} import org.apache.commons.io.{FileUtils, IOUtils} import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.eclipse.jetty.proxy.ProxyServlet +import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods import org.json4s.jackson.JsonMethods._ @@ -258,8 +260,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers getContentAndCode("foobar")._1 should be (HttpServletResponse.SC_NOT_FOUND) } - test("relative links are prefixed with uiRoot (spark.ui.proxyBase)") { - val proxyBaseBeforeTest = System.getProperty("spark.ui.proxyBase") + test("static relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase") val page = new HistoryPage(server) val request = mock[HttpServletRequest] @@ -267,7 +268,6 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers // when System.setProperty("spark.ui.proxyBase", uiRoot) val response = page.render(request) - System.setProperty("spark.ui.proxyBase", Option(proxyBaseBeforeTest).getOrElse("")) // then val urls = response \\ "@href" map (_.toString) @@ -275,6 +275,80 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers all (siteRelativeLinks) should startWith (uiRoot) } + test("ajax rendered relative links are prefixed with uiRoot (spark.ui.proxyBase)") { + val uiRoot = "/testwebproxybase" + System.setProperty("spark.ui.proxyBase", uiRoot) + + server.stop() + + val conf = new SparkConf() + .set("spark.history.fs.logDirectory", logDir) + .set("spark.history.fs.update.interval", "0") + .set("spark.testing", "true") + + provider = new FsHistoryProvider(conf) + provider.checkForLogs() + val securityManager = new SecurityManager(conf) + + server = new HistoryServer(conf, provider, securityManager, 18080) + server.initialize() + server.bind() + + val port = server.boundPort + + val servlet = new ProxyServlet { + override def rewriteTarget(request: HttpServletRequest): String = { + // servlet acts like a proxy that redirects calls made on + // spark.ui.proxyBase context path to the normal servlet handlers operating off "/" + val sb = request.getRequestURL() + + if (request.getQueryString() != null) { + sb.append(s"?${request.getQueryString()}") + } + + val proxyidx = sb.indexOf(uiRoot) + sb.delete(proxyidx, proxyidx + uiRoot.length).toString + } + } + + val contextHandler = new ServletContextHandler + val holder = new ServletHolder(servlet) + contextHandler.setContextPath(uiRoot) + contextHandler.addServlet(holder, "/") + server.attachHandler(contextHandler) + + implicit val webDriver: WebDriver = new HtmlUnitDriver(true) { + getWebClient.getOptions.setThrowExceptionOnScriptError(false) + } + + try { + val url = s"http://localhost:$port" + + go to s"$url$uiRoot" + + // expect the ajax call to finish in 5 seconds + implicitlyWait(org.scalatest.time.Span(5, org.scalatest.time.Seconds)) + + // once this findAll call returns, we know the ajax load of the table completed + findAll(ClassNameQuery("odd")) + + val links = findAll(TagNameQuery("a")) + .map(_.attribute("href")) + .filter(_.isDefined) + .map(_.get) + .filter(_.startsWith(url)).toList + + // there are atleast some URL links that were generated via javascript, + // and they all contain the spark.ui.proxyBase (uiRoot) + links.length should be > 4 + all(links) should startWith(url + uiRoot) + } 
finally { + contextHandler.stop() + quit() + } + + } + test("incomplete apps get refreshed") { implicit val webDriver: WebDriver = new HtmlUnitDriver diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index 58664e77d24a..ef5845a77c11 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -36,7 +36,7 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim override def beforeAll() { super.beforeAll() - sc = new SparkContext("local[2]", "test") + sc = new SparkContext("local[4]", "test") } override def afterAll() { diff --git a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala index 2802cd975292..5ff61b35c8bc 100644 --- a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala @@ -28,7 +28,7 @@ class LocalCheckpointSuite extends SparkFunSuite with LocalSparkContext { override def beforeEach(): Unit = { super.beforeEach() - sc = new SparkContext("local[2]", "test") + sc = new SparkContext("local[4]", "test") } test("transform storage level") { diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index b0d69de6e2ef..02df157be377 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -516,10 +516,10 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { pairs.saveAsNewAPIHadoopFile[NewFakeFormat]("ignored") /* - Check that configurable formats get configured: - ConfigTestFormat throws an exception if we try to write - to it when setConf hasn't been called first. - Assertion is in ConfigTestFormat.getRecordWriter. + * Check that configurable formats get configured: + * ConfigTestFormat throws an exception if we try to write + * to it when setConf hasn't been called first. + * Assertion is in ConfigTestFormat.getRecordWriter. */ pairs.saveAsNewAPIHadoopFile[ConfigTestFormat]("ignored") } @@ -544,7 +544,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { val e = intercept[SparkException] { pairs.saveAsNewAPIHadoopFile[NewFakeFormatWithCallback]("ignored") } - assert(e.getMessage contains "failed to write") + assert(e.getCause.getMessage contains "failed to write") assert(FakeWriterWithCallback.calledBy === "write,callback,close") assert(FakeWriterWithCallback.exception != null, "exception should be captured") @@ -725,8 +725,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } /* - These classes are fakes for testing - "saveNewAPIHadoopFile should call setConf if format is configurable". + These classes are fakes for testing saveAsHadoopFile/saveNewAPIHadoopFile. 
Unfortunately, they have to be top level classes, and not defined in the test method, because otherwise Scala won't generate no-args constructors and the test will therefore throw InstantiationException when saveAsNewAPIHadoopFile diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index acdf21df9a16..aa0705987d83 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -870,6 +870,19 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { verify(endpoint, never()).onDisconnected(any()) verify(endpoint, never()).onNetworkError(any(), any()) } + + test("isInRPCThread") { + val rpcEndpointRef = env.setupEndpoint("isInRPCThread", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case m => context.reply(rpcEnv.isInRPCThread) + } + }) + assert(rpcEndpointRef.askWithRetry[Boolean]("hello") === true) + assert(env.isInRPCThread === false) + env.stop(rpcEndpointRef) + } } class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 7f4859206e25..8a5ec37eeb66 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -202,8 +202,6 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit // Make sure expected events exist in the log file. val logData = EventLoggingListener.openEventLog(new Path(eventLogger.logPath), fileSystem) - val logStart = SparkListenerLogStart(SPARK_VERSION) - val lines = readLines(logData) val eventSet = mutable.Set( SparkListenerApplicationStart, SparkListenerBlockManagerAdded, @@ -216,19 +214,25 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit SparkListenerTaskStart, SparkListenerTaskEnd, SparkListenerApplicationEnd).map(Utils.getFormattedClassName) - lines.foreach { line => - eventSet.foreach { event => - if (line.contains(event)) { - val parsedEvent = JsonProtocol.sparkEventFromJson(parse(line)) - val eventType = Utils.getFormattedClassName(parsedEvent) - if (eventType == event) { - eventSet.remove(event) + Utils.tryWithSafeFinally { + val logStart = SparkListenerLogStart(SPARK_VERSION) + val lines = readLines(logData) + lines.foreach { line => + eventSet.foreach { event => + if (line.contains(event)) { + val parsedEvent = JsonProtocol.sparkEventFromJson(parse(line)) + val eventType = Utils.getFormattedClassName(parsedEvent) + if (eventType == event) { + eventSet.remove(event) + } } } } + assert(JsonProtocol.sparkEventFromJson(parse(lines(0))) === logStart) + assert(eventSet.isEmpty, "The following events are missing: " + eventSet.toSeq) + } { + logData.close() } - assert(JsonProtocol.sparkEventFromJson(parse(lines(0))) === logStart) - assert(eventSet.isEmpty, "The following events are missing: " + eventSet.toSeq) } private def readLines(in: InputStream): Seq[String] = { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 9e472f900b65..ee95e4ff7dbc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -183,9 +183,9 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local // ensure we reset the classloader after the test completes val originalClassLoader = Thread.currentThread.getContextClassLoader - try { + val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader) + Utils.tryWithSafeFinally { // load the exception from the jar - val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader) loader.addURL(jarFile.toURI.toURL) Thread.currentThread().setContextClassLoader(loader) val excClass: Class[_] = Utils.classForName("repro.MyException") @@ -209,8 +209,9 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local assert(expectedFailure.findFirstMatchIn(exceptionMessage).isDefined) assert(unknownFailure.findFirstMatchIn(exceptionMessage).isEmpty) - } finally { + } { Thread.currentThread.setContextClassLoader(originalClassLoader) + loader.close() } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index e5d408a16736..f4786e3931c9 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -473,7 +473,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() eventually(timeout(5 seconds), interval(50 milliseconds)) { val url = new URL( - sc.ui.get.appUIAddress.stripSuffix("/") + "/stages/stage/kill/?id=0") + sc.ui.get.webUrl.stripSuffix("/") + "/stages/stage/kill/?id=0") // SPARK-6846: should be POST only but YARN AM doesn't proxy POST getResponseCode(url, "GET") should be (200) getResponseCode(url, "POST") should be (200) @@ -486,7 +486,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() eventually(timeout(5 seconds), interval(50 milliseconds)) { val url = new URL( - sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs/job/kill/?id=0") + sc.ui.get.webUrl.stripSuffix("/") + "/jobs/job/kill/?id=0") // SPARK-6846: should be POST only but YARN AM doesn't proxy POST getResponseCode(url, "GET") should be (200) getResponseCode(url, "POST") should be (200) @@ -620,7 +620,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B test("live UI json application list") { withSpark(newSparkContext()) { sc => val appListRawJson = HistoryServerSuite.getUrl(new URL( - sc.ui.get.appUIAddress + "/api/v1/applications")) + sc.ui.get.webUrl + "/api/v1/applications")) val appListJsonAst = JsonMethods.parse(appListRawJson) appListJsonAst.children.length should be (1) val attempts = (appListJsonAst \ "attempts").children @@ -640,7 +640,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity).map(identity).groupBy(identity) rdd.count() - val stage0 = Source.fromURL(sc.ui.get.appUIAddress + + val stage0 = Source.fromURL(sc.ui.get.webUrl + "/stages/stage/?id=0&attempt=0&expandDagViz=true").mkString assert(stage0.contains("digraph G {\n subgraph clusterstage_0 {\n " + "label="Stage 0";\n subgraph ")) @@ -651,7 +651,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B assert(stage0.contains("{\n label="groupBy";\n " + "2 
[label="MapPartitionsRDD [2]")) - val stage1 = Source.fromURL(sc.ui.get.appUIAddress + + val stage1 = Source.fromURL(sc.ui.get.webUrl + "/stages/stage/?id=1&attempt=0&expandDagViz=true").mkString assert(stage1.contains("digraph G {\n subgraph clusterstage_1 {\n " + "label="Stage 1";\n subgraph ")) @@ -662,7 +662,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B assert(stage1.contains("{\n label="groupBy";\n " + "5 [label="MapPartitionsRDD [5]")) - val stage2 = Source.fromURL(sc.ui.get.appUIAddress + + val stage2 = Source.fromURL(sc.ui.get.webUrl + "/stages/stage/?id=2&attempt=0&expandDagViz=true").mkString assert(stage2.contains("digraph G {\n subgraph clusterstage_2 {\n " + "label="Stage 2";\n subgraph ")) @@ -687,7 +687,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } def goToUi(ui: SparkUI, path: String): Unit = { - go to (ui.appUIAddress.stripSuffix("/") + path) + go to (ui.webUrl.stripSuffix("/") + path) } def parseDate(json: JValue): Long = { @@ -699,6 +699,6 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } def apiUrl(ui: SparkUI, path: String): URL = { - new URL(ui.appUIAddress + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path) + new URL(ui.webUrl + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path) } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 4abcfb7e5191..68c7657cb315 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -66,7 +66,7 @@ class UISuite extends SparkFunSuite { withSpark(newSparkContext()) { sc => // test if the ui is visible, and all the expected tabs are visible eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(sc.ui.get.appUIAddress).mkString + val html = Source.fromURL(sc.ui.get.webUrl).mkString assert(!html.contains("random data that should not be present")) assert(html.toLowerCase.contains("stages")) assert(html.toLowerCase.contains("storage")) @@ -176,19 +176,18 @@ class UISuite extends SparkFunSuite { } } - test("verify appUIAddress contains the scheme") { + test("verify webUrl contains the scheme") { withSpark(newSparkContext()) { sc => val ui = sc.ui.get - val uiAddress = ui.appUIAddress - val uiHostPort = ui.appUIHostPort - assert(uiAddress.equals("http://" + uiHostPort)) + val uiAddress = ui.webUrl + assert(uiAddress.startsWith("http://") || uiAddress.startsWith("https://")) } } - test("verify appUIAddress contains the port") { + test("verify webUrl contains the port") { withSpark(newSparkContext()) { sc => val ui = sc.ui.get - val splitUIAddress = ui.appUIAddress.split(':') + val splitUIAddress = ui.webUrl.split(':') val boundPort = ui.boundPort assert(splitUIAddress(2).toInt == boundPort) } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 8418fa74d2c6..da853f1be8b9 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -403,7 +403,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with internal = false, countFailedValues = false, metadata = None) - taskInfo.accumulables ++= Seq(internalAccum, sqlAccum, userAccum) + 
taskInfo.setAccumulables(List(internalAccum, sqlAccum, userAccum)) val newTaskInfo = TaskUIData.dropInternalAndSQLAccumulables(taskInfo) assert(newTaskInfo.accumulables === Seq(userAccum)) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index d5146d70ebaa..85da79180fd0 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -788,11 +788,8 @@ private[spark] object JsonProtocolSuite extends Assertions { private def makeTaskInfo(a: Long, b: Int, c: Int, d: Long, speculative: Boolean) = { val taskInfo = new TaskInfo(a, b, c, d, "executor", "your kind sir", TaskLocality.NODE_LOCAL, speculative) - val (acc1, acc2, acc3) = - (makeAccumulableInfo(1), makeAccumulableInfo(2), makeAccumulableInfo(3, internal = true)) - taskInfo.accumulables += acc1 - taskInfo.accumulables += acc2 - taskInfo.accumulables += acc3 + taskInfo.setAccumulables( + List(makeAccumulableInfo(1), makeAccumulableInfo(2), makeAccumulableInfo(3, internal = true))) taskInfo } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 15ef32f21d90..feacfb7642f2 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -264,7 +264,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { val hour = minute * 60 def str: (Long) => String = Utils.msDurationToString(_) - val sep = new DecimalFormatSymbols(Locale.getDefault()).getDecimalSeparator() + val sep = new DecimalFormatSymbols(Locale.US).getDecimalSeparator assert(str(123) === "123 ms") assert(str(second) === "1" + sep + "0 s") diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index 3066e9996abd..335ecb9320ab 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -49,9 +49,6 @@ class OpenHashMapSuite extends SparkFunSuite with Matchers { intercept[IllegalArgumentException] { new OpenHashMap[String, Int](-1) } - intercept[IllegalArgumentException] { - new OpenHashMap[String, String](0) - } } test("primitive value") { diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index 2607a543dd61..210bc5c09974 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -176,4 +176,9 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(set.size === 1000) assert(set.capacity > 1000) } + + test("SPARK-18200 Support zero as an initial set size") { + val set = new OpenHashSet[Long](0) + assert(set.size === 0) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala index 508e737b725b..f5ee428020fd 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala @@ -49,9 +49,6 @@ 
class PrimitiveKeyOpenHashMapSuite extends SparkFunSuite with Matchers { intercept[IllegalArgumentException] { new PrimitiveKeyOpenHashMap[Int, Int](-1) } - intercept[IllegalArgumentException] { - new PrimitiveKeyOpenHashMap[Int, Int](0) - } } test("basic operations") { diff --git a/docs/building-spark.md b/docs/building-spark.md index ebe46a42a15c..2b404bd3e116 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -13,6 +13,7 @@ redirect_from: "building-with-maven.html" The Maven-based build is the build of reference for Apache Spark. Building Spark using Maven requires Maven 3.3.9 or newer and Java 7+. +Note that support for Java 7 is deprecated as of Spark 2.0.0 and may be removed in Spark 2.2.0. ### Setting up Maven's Memory Usage @@ -79,6 +80,9 @@ Because HDFS is not protocol-compatible across versions, if you want to read fro +Note that support for versions of Hadoop before 2.6 are deprecated as of Spark 2.1.0 and may be +removed in Spark 2.2.0. + You can enable the `yarn` profile and optionally set the `yarn.version` property if it is different from `hadoop.version`. Spark only supports YARN versions 2.2.0 and later. @@ -129,6 +133,8 @@ To produce a Spark package compiled with Scala 2.10, use the `-Dscala-2.10` prop ./dev/change-scala-version.sh 2.10 ./build/mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package + +Note that support for Scala 2.10 is deprecated as of Spark 2.1.0 and may be removed in Spark 2.2.0. ## Building submodules individually diff --git a/docs/configuration.md b/docs/configuration.md index 6600cb6c0ac0..d0acd944dd6b 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -767,7 +767,7 @@ Apart from these, the following properties are also available, and may be useful spark.kryo.referenceTracking - true (false when using Spark SQL Thrift Server) + true Whether to track references to the same object when serializing data with Kryo, which is necessary if your object graphs have loops and useful for efficiency if they contain multiple @@ -838,8 +838,7 @@ Apart from these, the following properties are also available, and may be useful spark.serializer - org.apache.spark.serializer.
JavaSerializer (org.apache.spark.serializer.
- KryoSerializer when using Spark SQL Thrift Server) + org.apache.spark.serializer.
JavaSerializer Class to use for serializing objects that will be sent over the network or need to be cached @@ -1035,6 +1034,22 @@ Apart from these, the following properties are also available, and may be useful its contents do not match those of the source. + + spark.files.maxPartitionBytes + 134217728 (128 MB) + + The maximum number of bytes to pack into a single partition when reading files. + + + + spark.files.openCostInBytes + 4194304 (4 MB) + + The estimated cost to open a file, measured by the number of bytes could be scanned in the same + time. This is used when putting multiple files into a partition. It is better to over estimate, + then the partitions with small files will be faster than partitions with bigger files. + + spark.hadoop.cloneConf false @@ -1890,6 +1905,21 @@ showDF(properties, numRows = 200, truncate = FALSE) spark.r.shell.command is used for sparkR shell while spark.r.driver.command is used for running R script. + + spark.r.backendConnectionTimeout + 6000 + + Connection timeout set by R process on its connection to RBackend in seconds. + + + + spark.r.heartBeatInterval + 100 + + Interval for heartbeats sents from SparkR backend to R process to prevent connection timeout. + + + #### Deploy diff --git a/docs/index.md b/docs/index.md index a7a92f6c4f6d..fe51439ae08d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -28,6 +28,10 @@ Spark runs on Java 7+, Python 2.6+/3.4+ and R 3.1+. For the Scala API, Spark {{s uses Scala {{site.SCALA_BINARY_VERSION}}. You will need to use a compatible Scala version ({{site.SCALA_BINARY_VERSION}}.x). +Note that support for Java 7 and Python 2.6 are deprecated as of Spark 2.0.0, and support for +Scala 2.10 and versions of Hadoop before 2.6 are deprecated as of Spark 2.1.0, and may be +removed in Spark 2.2.0. + # Running the Examples and Shell Spark comes with several sample programs. Scala, Java, Python and R examples are in the diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index bb2e404330cc..b10793d83ec6 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -46,7 +46,7 @@ parameter to select between these two algorithms, or leave it unset and Spark wi For more background and more details about the implementation of binomial logistic regression, refer to the documentation of [logistic regression in `spark.mllib`](mllib-linear-methods.html#logistic-regression). -**Example** +**Examples** The following example shows how to train binomial and multinomial logistic regression models for binary classification with elastic net regularization. `elasticNetParam` corresponds to @@ -137,7 +137,7 @@ We minimize the weighted negative log-likelihood, using a multinomial response m For a detailed derivation please see [here](https://en.wikipedia.org/wiki/Multinomial_logistic_regression#As_a_log-linear_model). -**Example** +**Examples** The following example shows how to train a multiclass logistic regression model with elastic net regularization. @@ -164,7 +164,7 @@ model with elastic net regularization. Decision trees are a popular family of classification and regression methods. More information about the `spark.ml` implementation can be found further in the [section on decision trees](#decision-trees). -**Example** +**Examples** The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. 
We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the Decision Tree algorithm can recognize. @@ -201,7 +201,7 @@ More details on parameters can be found in the [Python API documentation](api/py Random forests are a popular family of classification and regression methods. More information about the `spark.ml` implementation can be found further in the [section on random forests](#random-forests). -**Example** +**Examples** The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. @@ -234,7 +234,7 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classificat Gradient-boosted trees (GBTs) are a popular classification and regression method using ensembles of decision trees. More information about the `spark.ml` implementation can be found further in the [section on GBTs](#gradient-boosted-trees-gbts). -**Example** +**Examples** The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. @@ -284,7 +284,7 @@ The number of nodes `$N$` in the output layer corresponds to the number of class MLPC employs backpropagation for learning the model. We use the logistic loss function for optimization and L-BFGS as an optimization routine. -**Example** +**Examples**
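Since the paragraph above describes how MLPC is trained (backpropagation with the logistic loss and L-BFGS), a minimal Scala sketch of the corresponding API may help. It assumes an existing `SparkSession` named `spark` and the sample multiclass LibSVM file under `data/mllib/`; both are illustrative assumptions, not part of this patch.

{% highlight scala %}
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// assumption: an existing SparkSession named `spark` and the bundled sample data file
val data = spark.read.format("libsvm")
  .load("data/mllib/sample_multiclass_classification_data.txt")
val Array(train, test) = data.randomSplit(Array(0.6, 0.4), seed = 1234L)

// input layer of size 4 (features), two hidden layers, output layer of size 3 (classes)
val layers = Array[Int](4, 5, 4, 3)
val trainer = new MultilayerPerceptronClassifier()
  .setLayers(layers)
  .setBlockSize(128)
  .setSeed(1234L)
  .setMaxIter(100)

val model = trainer.fit(train)
val result = model.transform(test)
val evaluator = new MulticlassClassificationEvaluator().setMetricName("accuracy")
println(s"Test set accuracy = ${evaluator.evaluate(result.select("prediction", "label"))}")
{% endhighlight %}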
@@ -311,7 +311,7 @@ MLPC employs backpropagation for learning the model. We use the logistic loss fu Predictions are done by evaluating each binary classifier and the index of the most confident classifier is output as label. -**Example** +**Examples** The example below demonstrates how to load the [Iris dataset](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale), parse it as a DataFrame and perform multiclass classification using `OneVsRest`. The test error is calculated to measure the algorithm accuracy. @@ -348,7 +348,7 @@ naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-c and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). More information can be found in the section on [Naive Bayes in MLlib](mllib-naive-bayes.html#naive-bayes-sparkmllib). -**Example** +**Examples**
@@ -383,7 +383,7 @@ summaries is similar to the logistic regression case. > When fitting LinearRegressionModel without intercept on dataset with constant nonzero column by "l-bfgs" solver, Spark MLlib outputs zero coefficients for constant nonzero columns. This behavior is the same as R glmnet but different from LIBSVM. -**Example** +**Examples** The following example demonstrates training an elastic net regularized linear @@ -511,7 +511,7 @@ others. -**Example** +**Examples** The following example demonstrates training a GLM with a Gaussian response and identity link function and extracting model summary statistics. @@ -544,7 +544,7 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression. Decision trees are a popular family of classification and regression methods. More information about the `spark.ml` implementation can be found further in the [section on decision trees](#decision-trees). -**Example** +**Examples** The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the Decision Tree algorithm can recognize. @@ -579,7 +579,7 @@ More details on parameters can be found in the [Python API documentation](api/py Random forests are a popular family of classification and regression methods. More information about the `spark.ml` implementation can be found further in the [section on random forests](#random-forests). -**Example** +**Examples** The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. @@ -612,7 +612,7 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression. Gradient-boosted trees (GBTs) are a popular regression method using ensembles of decision trees. More information about the `spark.ml` implementation can be found further in the [section on GBTs](#gradient-boosted-trees-gbts). -**Example** +**Examples** Note: For this example dataset, `GBTRegressor` actually only needs 1 iteration, but that will not be true in general. @@ -700,7 +700,7 @@ The implementation matches the result from R's survival function > When fitting AFTSurvivalRegressionModel without intercept on dataset with constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero columns. This behavior is different from R survival::survreg. -**Example** +**Examples**
@@ -765,7 +765,7 @@ is treated as piecewise linear function. The rules for prediction therefore are: predictions of the two closest features. In case there are multiple values with the same feature then the same rules as in previous point are used. -### Examples +**Examples**
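To make the interpolation rules above concrete, here is a minimal sketch that fits an isotonic model and inspects the boundary/prediction pairs the rules refer to; the `spark` session and the sample data path are assumptions for illustration.

{% highlight scala %}
import org.apache.spark.ml.regression.IsotonicRegression

// assumption: an existing SparkSession named `spark` and the bundled sample data file
val dataset = spark.read.format("libsvm")
  .load("data/mllib/sample_isotonic_regression_libsvm_data.txt")

val ir = new IsotonicRegression()   // non-decreasing (isotonic) fit by default
val model = ir.fit(dataset)

// the known features/predictions that the prediction rules above interpolate between
println(s"Boundaries in increasing order: ${model.boundaries}")
println(s"Predictions associated with the boundaries: ${model.predictions}")

model.transform(dataset).show()
{% endhighlight %}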
diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md index 8a0a61cb595e..eedacb12bc46 100644 --- a/docs/ml-clustering.md +++ b/docs/ml-clustering.md @@ -65,7 +65,7 @@ called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf). -### Example +**Examples**
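As a rough sketch of the `spark.ml` k-means API (which defaults to the kmeans|| initialization mentioned above), assuming an existing `SparkSession` named `spark` and the bundled sample file:

{% highlight scala %}
import org.apache.spark.ml.clustering.KMeans

// assumption: an existing SparkSession named `spark` and the bundled sample data file
val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

val kmeans = new KMeans().setK(2).setSeed(1L)   // initMode defaults to "k-means||"
val model = kmeans.fit(dataset)

println(s"Within Set Sum of Squared Errors = ${model.computeCost(dataset)}")
model.clusterCenters.foreach(println)
{% endhighlight %}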
@@ -94,6 +94,8 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.clustering. and generates a `LDAModel` as the base model. Expert users may cast a `LDAModel` generated by `EMLDAOptimizer` to a `DistributedLDAModel` if needed. +**Examples** +
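A minimal sketch of fitting an `LDA` estimator and obtaining an `LDAModel` as described above; the data path and `spark` session are illustrative assumptions, and casting to `DistributedLDAModel` only applies when the "em" optimizer is used.

{% highlight scala %}
import org.apache.spark.ml.clustering.LDA

// assumption: an existing SparkSession named `spark` and the bundled sample data file
val dataset = spark.read.format("libsvm").load("data/mllib/sample_lda_libsvm_data.txt")

// setOptimizer("em") would yield a model that can be cast to DistributedLDAModel
val lda = new LDA().setK(10).setMaxIter(10)
val model = lda.fit(dataset)

println(s"Lower bound on the log likelihood: ${model.logLikelihood(dataset)}")
println(s"Upper bound on perplexity: ${model.logPerplexity(dataset)}")
model.describeTopics(3).show(false)
model.transform(dataset).show(false)
{% endhighlight %}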
@@ -128,7 +130,7 @@ Bisecting K-means can often be much faster than regular K-means, but it will gen `BisectingKMeans` is implemented as an `Estimator` and generates a `BisectingKMeansModel` as the base model. -### Example +**Examples**
@@ -210,7 +212,7 @@ model. -### Example +**Examples**
diff --git a/docs/ml-collaborative-filtering.md b/docs/ml-collaborative-filtering.md index 1d02d6933cb4..4d19b4069a1f 100644 --- a/docs/ml-collaborative-filtering.md +++ b/docs/ml-collaborative-filtering.md @@ -59,7 +59,7 @@ This approach is named "ALS-WR" and discussed in the paper It makes `regParam` less dependent on the scale of the dataset, so we can apply the best parameter learned from a sampled subset to the full dataset and expect similar performance. -## Examples +**Examples**
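To illustrate where `regParam` enters the ALS API described above, a minimal sketch on an inline toy ratings DataFrame; the column names, toy data, and the `spark` session are assumptions for illustration only.

{% highlight scala %}
import org.apache.spark.ml.recommendation.ALS

// assumption: an existing SparkSession named `spark`
import spark.implicits._

val ratings = Seq(
  (0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0),
  (1, 2, 4.0), (2, 0, 5.0), (2, 2, 1.0)
).toDF("userId", "movieId", "rating")

val als = new ALS()
  .setMaxIter(5)
  .setRegParam(0.01)   // per the ALS-WR weighting above, less sensitive to dataset scale
  .setUserCol("userId")
  .setItemCol("movieId")
  .setRatingCol("rating")

val model = als.fit(ratings)
model.transform(ratings).show()
{% endhighlight %}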
diff --git a/docs/ml-features.md b/docs/ml-features.md index 64c6a160239c..19ec5746978a 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -112,6 +112,8 @@ can then be used as features for prediction, document similarity calculations, e Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#word2vec) for more details. +**Examples** + In the following code segment, we start with a set of documents, each of which is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm.
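A minimal sketch of the flow just described (documents as word sequences mapped to fixed-size vectors); the toy sentences and the `spark` session are illustrative assumptions.

{% highlight scala %}
import org.apache.spark.ml.feature.Word2Vec

// assumption: an existing SparkSession named `spark`
val documentDF = spark.createDataFrame(Seq(
  "Hi I heard about Spark".split(" "),
  "I wish Java could use case classes".split(" "),
  "Logistic regression models are neat".split(" ")
).map(Tuple1.apply)).toDF("text")

val word2Vec = new Word2Vec()
  .setInputCol("text")
  .setOutputCol("result")
  .setVectorSize(3)
  .setMinCount(0)

val model = word2Vec.fit(documentDF)
model.transform(documentDF).show(false)   // each document becomes a 3-dimensional vector
{% endhighlight %}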
@@ -220,6 +222,8 @@ for more details on the API. Alternatively, users can set parameter "gaps" to false indicating the regex "pattern" denotes "tokens" rather than splitting gaps, and find all matching occurrences as the tokenization result. +**Examples** +
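The `gaps` behaviour described above can be sketched as follows; the sentences and the `spark` session are illustrative assumptions.

{% highlight scala %}
import org.apache.spark.ml.feature.RegexTokenizer

// assumption: an existing SparkSession named `spark`
import spark.implicits._

val sentenceDF = Seq(
  (0, "Hi I heard about Spark"),
  (1, "Logistic,regression,models,are,neat")
).toDF("id", "sentence")

// gaps = false: the pattern matches the tokens themselves rather than the separators
val regexTokenizer = new RegexTokenizer()
  .setInputCol("sentence")
  .setOutputCol("words")
  .setPattern("\\w+")
  .setGaps(false)

regexTokenizer.transform(sentenceDF).select("sentence", "words").show(false)
{% endhighlight %}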
@@ -321,6 +325,8 @@ An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (t `NGram` takes as input a sequence of strings (e.g. the output of a [Tokenizer](ml-features.html#tokenizer)). The parameter `n` is used to determine the number of terms in each $n$-gram. The output will consist of a sequence of $n$-grams where each $n$-gram is represented by a space-delimited string of $n$ consecutive words. If the input sequence contains fewer than `n` strings, no output is produced. +**Examples** +
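A minimal sketch of producing bigrams from tokenized text, under the same assumptions (a `SparkSession` named `spark`, toy data):

{% highlight scala %}
import org.apache.spark.ml.feature.NGram

// assumption: an existing SparkSession named `spark`
import spark.implicits._

val wordsDF = Seq(
  (0, Array("Hi", "I", "heard", "about", "Spark")),
  (1, Array("I", "wish", "Java", "could", "use", "case", "classes"))
).toDF("id", "words")

// n = 2 produces space-delimited bigrams; rows with fewer than n tokens produce no output
val ngram = new NGram().setN(2).setInputCol("words").setOutputCol("ngrams")
ngram.transform(wordsDF).select("ngrams").show(false)
{% endhighlight %}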
@@ -358,6 +364,8 @@ for binarization. Feature values greater than the threshold are binarized to 1.0 to or less than the threshold are binarized to 0.0. Both Vector and Double types are supported for `inputCol`. +**Examples** +
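A minimal sketch of thresholding a continuous column, assuming an existing `SparkSession` named `spark` and illustrative data:

{% highlight scala %}
import org.apache.spark.ml.feature.Binarizer

// assumption: an existing SparkSession named `spark`
import spark.implicits._

val continuousDF = Seq((0, 0.1), (1, 0.8), (2, 0.2)).toDF("id", "feature")

val binarizer = new Binarizer()
  .setInputCol("feature")
  .setOutputCol("binarized_feature")
  .setThreshold(0.5)   // values above 0.5 become 1.0, the rest become 0.0

binarizer.transform(continuousDF).show()
{% endhighlight %}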
@@ -388,6 +396,8 @@ for more details on the API. [PCA](http://en.wikipedia.org/wiki/Principal_component_analysis) is a statistical procedure that uses an orthogonal transformation to convert a set of observations of possibly correlated variables into a set of values of linearly uncorrelated variables called principal components. A [PCA](api/scala/index.html#org.apache.spark.ml.feature.PCA) class trains a model to project vectors to a low-dimensional space using PCA. The example below shows how to project 5-dimensional feature vectors into 3-dimensional principal components. +**Examples** +
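A minimal sketch of projecting 5-dimensional vectors onto 3 principal components, with illustrative data and an assumed `spark` session:

{% highlight scala %}
import org.apache.spark.ml.feature.PCA
import org.apache.spark.ml.linalg.Vectors

// assumption: an existing SparkSession named `spark`
val data = Array(
  Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
  Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
  Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
)
val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features")

val pca = new PCA()
  .setInputCol("features")
  .setOutputCol("pcaFeatures")
  .setK(3)              // keep the top 3 principal components
  .fit(df)

pca.transform(df).select("pcaFeatures").show(false)
{% endhighlight %}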
@@ -418,6 +428,8 @@ for more details on the API. [Polynomial expansion](http://en.wikipedia.org/wiki/Polynomial_expansion) is the process of expanding your features into a polynomial space, which is formulated by an n-degree combination of original dimensions. A [PolynomialExpansion](api/scala/index.html#org.apache.spark.ml.feature.PolynomialExpansion) class provides this functionality. The example below shows how to expand your features into a 3-degree polynomial space. +**Examples** +
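A minimal sketch of expanding 2-dimensional features into a 3-degree polynomial space, with illustrative data and an assumed `spark` session:

{% highlight scala %}
import org.apache.spark.ml.feature.PolynomialExpansion
import org.apache.spark.ml.linalg.Vectors

// assumption: an existing SparkSession named `spark`
val data = Array(
  Vectors.dense(2.0, 1.0),
  Vectors.dense(0.0, 0.0),
  Vectors.dense(3.0, -1.0)
)
val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features")

val polyExpansion = new PolynomialExpansion()
  .setInputCol("features")
  .setOutputCol("polyFeatures")
  .setDegree(3)   // expand into the 3-degree polynomial space

polyExpansion.transform(df).select("polyFeatures").show(false)
{% endhighlight %}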
@@ -458,6 +470,8 @@ for the transform is unitary. No shift is applied to the transformed sequence (e.g. the $0$th element of the transformed sequence is the $0$th DCT coefficient and _not_ the $N/2$th). +**Examples** +
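A minimal sketch of the forward DCT described above, with illustrative data and an assumed `spark` session:

{% highlight scala %}
import org.apache.spark.ml.feature.DCT
import org.apache.spark.ml.linalg.Vectors

// assumption: an existing SparkSession named `spark`
val data = Seq(
  Vectors.dense(0.0, 1.0, -2.0, 3.0),
  Vectors.dense(-1.0, 2.0, 4.0, -7.0),
  Vectors.dense(14.0, -2.0, -5.0, 1.0)
)
val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features")

val dct = new DCT()
  .setInputCol("features")
  .setOutputCol("featuresDCT")
  .setInverse(false)   // forward DCT-II, scaled so the transform matrix is unitary

dct.transform(df).select("featuresDCT").show(false)
{% endhighlight %}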
@@ -663,6 +677,8 @@ for more details on the API. [One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features. +**Examples** +
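A minimal sketch of indexing string labels and then one-hot encoding the resulting indices, with illustrative data and an assumed `spark` session:

{% highlight scala %}
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

// assumption: an existing SparkSession named `spark`
import spark.implicits._

val df = Seq(
  (0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")
).toDF("id", "category")

// first map the string labels to indices, then one-hot encode the indices
val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .fit(df)
val indexed = indexer.transform(df)

val encoder = new OneHotEncoder()
  .setInputCol("categoryIndex")
  .setOutputCol("categoryVec")
encoder.transform(indexed).show()
{% endhighlight %}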
@@ -701,6 +717,8 @@ It can both automatically decide which features are categorical and convert orig Indexing categorical features allows algorithms such as Decision Trees and Tree Ensembles to treat categorical features appropriately, improving performance. +**Examples** + In the example below, we read in a dataset of labeled points and then use `VectorIndexer` to decide which features should be treated as categorical. We transform the categorical feature values to their indices. This transformed data could then be passed to algorithms such as `DecisionTreeRegressor` that handle categorical features.
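A minimal sketch of the flow just described; the `spark` session and the bundled sample file path are assumptions for illustration.

{% highlight scala %}
import org.apache.spark.ml.feature.VectorIndexer

// assumption: an existing SparkSession named `spark` and the bundled sample data file
val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

val indexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexed")
  .setMaxCategories(10)   // features with more than 10 distinct values stay continuous

val indexerModel = indexer.fit(data)
println(s"Chose ${indexerModel.categoryMaps.size} categorical features")
indexerModel.transform(data).show()
{% endhighlight %}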
@@ -729,11 +747,65 @@ for more details on the API.
+## Interaction + +`Interaction` is a `Transformer` which takes vector or double-valued columns, and generates a single vector column that contains the product of all combinations of one value from each input column. + +For example, if you have 2 vector type columns each of which has 3 dimensions as input columns, then then you'll get a 9-dimensional vector as the output column. + +**Examples** + +Assume that we have the following DataFrame with the columns "id1", "vec1", and "vec2": + +~~~~ + id1|vec1 |vec2 + ---|--------------|-------------- + 1 |[1.0,2.0,3.0] |[8.0,4.0,5.0] + 2 |[4.0,3.0,8.0] |[7.0,9.0,8.0] + 3 |[6.0,1.0,9.0] |[2.0,3.0,6.0] + 4 |[10.0,8.0,6.0]|[9.0,4.0,5.0] + 5 |[9.0,2.0,7.0] |[10.0,7.0,3.0] + 6 |[1.0,1.0,4.0] |[2.0,8.0,4.0] +~~~~ + +Applying `Interaction` with those input columns, +then `interactedCol` as the output column contains: + +~~~~ + id1|vec1 |vec2 |interactedCol + ---|--------------|--------------|------------------------------------------------------ + 1 |[1.0,2.0,3.0] |[8.0,4.0,5.0] |[8.0,4.0,5.0,16.0,8.0,10.0,24.0,12.0,15.0] + 2 |[4.0,3.0,8.0] |[7.0,9.0,8.0] |[56.0,72.0,64.0,42.0,54.0,48.0,112.0,144.0,128.0] + 3 |[6.0,1.0,9.0] |[2.0,3.0,6.0] |[36.0,54.0,108.0,6.0,9.0,18.0,54.0,81.0,162.0] + 4 |[10.0,8.0,6.0]|[9.0,4.0,5.0] |[360.0,160.0,200.0,288.0,128.0,160.0,216.0,96.0,120.0] + 5 |[9.0,2.0,7.0] |[10.0,7.0,3.0]|[450.0,315.0,135.0,100.0,70.0,30.0,350.0,245.0,105.0] + 6 |[1.0,1.0,4.0] |[2.0,8.0,4.0] |[12.0,48.0,24.0,12.0,48.0,24.0,48.0,192.0,96.0] +~~~~ + +
+
+ +Refer to the [Interaction Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Interaction) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/InteractionExample.scala %} +
+ +
+ +Refer to the [Interaction Java docs](api/java/org/apache/spark/ml/feature/Interaction.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaInteractionExample.java %} +
+
## Normalizer `Normalizer` is a `Transformer` which transforms a dataset of `Vector` rows, normalizing each `Vector` to have unit norm. It takes parameter `p`, which specifies the [p-norm](http://en.wikipedia.org/wiki/Norm_%28mathematics%29#p-norm) used for normalization. ($p = 2$ by default.) This normalization can help standardize your input data and improve the behavior of learning algorithms. +**Examples** + The following example demonstrates how to load a dataset in libsvm format and then normalize each row to have unit $L^1$ norm and unit $L^\infty$ norm.
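A minimal inline sketch of the same idea (the data here is illustrative rather than the libsvm file referenced above, and a `SparkSession` named `spark` is assumed):

{% highlight scala %}
import org.apache.spark.ml.feature.Normalizer
import org.apache.spark.ml.linalg.Vectors

// assumption: an existing SparkSession named `spark`; inline data instead of a libsvm file
val dataFrame = spark.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 0.5, -1.0)),
  (1, Vectors.dense(2.0, 1.0, 1.0)),
  (2, Vectors.dense(4.0, 10.0, 2.0))
)).toDF("id", "features")

// normalize each row to unit L^1 norm
val normalizer = new Normalizer()
  .setInputCol("features")
  .setOutputCol("normFeatures")
  .setP(1.0)
normalizer.transform(dataFrame).show(false)

// the p-norm can also be overridden at transform time, e.g. for L^inf normalization
normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity).show(false)
{% endhighlight %}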
@@ -774,6 +846,8 @@ for more details on the API. Note that if the standard deviation of a feature is zero, it will return default `0.0` value in the `Vector` for that feature. +**Examples** + The following example demonstrates how to load a dataset in libsvm format and then normalize each feature to have unit standard deviation.
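A minimal sketch of fitting and applying the scaler described above; the `spark` session and the bundled sample file path are assumptions for illustration.

{% highlight scala %}
import org.apache.spark.ml.feature.StandardScaler

// assumption: an existing SparkSession named `spark` and the bundled sample data file
val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")
  .setWithStd(true)     // scale to unit standard deviation
  .setWithMean(false)   // do not center, preserving sparsity

val scalerModel = scaler.fit(data)   // computes the summary statistics
scalerModel.transform(data).show()
{% endhighlight %}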
@@ -819,6 +893,8 @@ For the case `$E_{max} == E_{min}$`, `$Rescaled(e_i) = 0.5 * (max + min)$` Note that since zero values will probably be transformed to non-zero values, output of the transformer will be `DenseVector` even for sparse input. +**Examples** + The following example demonstrates how to load a dataset in libsvm format and then rescale each feature to [0, 1].
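A minimal inline sketch of rescaling each feature to [0, 1] (illustrative data rather than a libsvm file, with an assumed `SparkSession` named `spark`):

{% highlight scala %}
import org.apache.spark.ml.feature.MinMaxScaler
import org.apache.spark.ml.linalg.Vectors

// assumption: an existing SparkSession named `spark`; inline data instead of a libsvm file
val dataFrame = spark.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 0.1, -1.0)),
  (1, Vectors.dense(2.0, 1.1, 1.0)),
  (2, Vectors.dense(3.0, 10.1, 3.0))
)).toDF("id", "features")

// default output range is [min, max] = [0.0, 1.0]
val scaler = new MinMaxScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")

val scalerModel = scaler.fit(dataFrame)   // computes per-feature min/max
scalerModel.transform(dataFrame).show(false)
{% endhighlight %}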
@@ -860,6 +936,8 @@ data, and thus does not destroy any sparsity. `MaxAbsScaler` computes summary statistics on a data set and produces a `MaxAbsScalerModel`. The model can then transform each feature individually to range [-1, 1]. +**Examples** + The following example demonstrates how to load a dataset in libsvm format and then rescale each feature to [-1, 1].
@@ -903,6 +981,8 @@ Note also that the splits that you provide have to be in strictly increasing order More details can be found in the API docs for [Bucketizer](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer). +**Examples** + The following example demonstrates how to bucketize a column of `Double`s into another indexed column.
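A minimal sketch of bucketizing a `Double` column with strictly increasing splits; the data and the `spark` session are illustrative assumptions.

{% highlight scala %}
import org.apache.spark.ml.feature.Bucketizer

// assumption: an existing SparkSession named `spark`
import spark.implicits._

// splits must be strictly increasing; +/-infinity covers values outside the known range
val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
val dataFrame = Seq(-999.9, -0.5, -0.3, 0.0, 0.2, 999.9).toDF("features")

val bucketizer = new Bucketizer()
  .setInputCol("features")
  .setOutputCol("bucketedFeatures")
  .setSplits(splits)

bucketizer.transform(dataFrame).show()
{% endhighlight %}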
@@ -951,6 +1031,8 @@ v_N \end{pmatrix} \]` +**Examples** + The example below demonstrates how to transform vectors using a transforming vector value.
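A minimal sketch of the element-wise (Hadamard) product with a fixed transforming vector; the data, column names, and `spark` session are illustrative assumptions.

{% highlight scala %}
import org.apache.spark.ml.feature.ElementwiseProduct
import org.apache.spark.ml.linalg.Vectors

// assumption: an existing SparkSession named `spark`
val dataFrame = spark.createDataFrame(Seq(
  ("a", Vectors.dense(1.0, 2.0, 3.0)),
  ("b", Vectors.dense(4.0, 5.0, 6.0))
)).toDF("id", "vector")

val transformer = new ElementwiseProduct()
  .setScalingVec(Vectors.dense(0.0, 1.0, 2.0))   // the "transforming vector" w
  .setInputCol("vector")
  .setOutputCol("transformedVector")

// each output element is the element-wise product with the scaling vector
transformer.transform(dataFrame).show()
{% endhighlight %}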
@@ -1338,14 +1420,14 @@ for more details on the API. `ChiSqSelector` stands for Chi-Squared feature selection. It operates on labeled data with categorical features. ChiSqSelector uses the [Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) to decide which -features to choose. It supports three selection methods: `KBest`, `Percentile` and `FPR`: +features to choose. It supports three selection methods: `numTopFeatures`, `percentile`, `fpr`: -* `KBest` chooses the `k` top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. -* `Percentile` is similar to `KBest` but chooses a fraction of all features instead of a fixed number. -* `FPR` chooses all features whose false positive rate meets some threshold. +* `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. +* `percentile` is similar to `numTopFeatures` but chooses a fraction of all features instead of a fixed number. +* `fpr` chooses all features whose p-value is below a threshold, thus controlling the false positive rate of selection. -By default, the selection method is `KBest`, the default number of top features is 50. User can use -`setNumTopFeatures`, `setPercentile` and `setAlpha` to set different selection methods. +By default, the selection method is `numTopFeatures`, with the default number of top features set to 50. +The user can choose a selection method using `setSelectorType`. **Examples** diff --git a/docs/ml-tuning.md b/docs/ml-tuning.md index 2ca90c7092fd..e4b070331db4 100644 --- a/docs/ml-tuning.md +++ b/docs/ml-tuning.md @@ -62,7 +62,7 @@ To help construct the parameter grid, users can use the [`ParamGridBuilder`](api After identifying the best `ParamMap`, `CrossValidator` finally re-fits the `Estimator` using the best `ParamMap` and the entire dataset. -## Example: model selection via cross-validation +**Examples: model selection via cross-validation** The following example demonstrates using `CrossValidator` to select from a grid of parameters. @@ -102,7 +102,7 @@ It splits the dataset into these two parts using the `trainRatio` parameter. For Like `CrossValidator`, `TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. -## Example: model selection via train validation split +**Examples: model selection via train validation split**
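To make the re-fitting behaviour of `CrossValidator` described above concrete, a minimal sketch with a toy text-classification pipeline follows; the tiny inline dataset, column names, and `spark` session are assumptions for illustration and far too small for meaningful cross-validation.

{% highlight scala %}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// assumption: an existing SparkSession named `spark`
import spark.implicits._

val training = Seq(
  (0L, "a b c d e spark", 1.0), (1L, "b d", 0.0),
  (2L, "spark f g h", 1.0), (3L, "hadoop mapreduce", 0.0),
  (4L, "b spark who", 1.0), (5L, "g d a y", 0.0),
  (6L, "spark fly", 1.0), (7L, "was mapreduce", 0.0)
).toDF("id", "text", "label")

val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")
val lr = new LogisticRegression().setMaxIter(10)
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))

// the grid of candidate ParamMaps that CrossValidator searches over
val paramGrid = new ParamGridBuilder()
  .addGrid(hashingTF.numFeatures, Array(10, 100))
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .build()

val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new BinaryClassificationEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(2)

// after selecting the best ParamMap, CrossValidator re-fits on the full training data
val cvModel = cv.fit(training)
{% endhighlight %}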
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 87e1e027e945..42568c312e70 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -227,22 +227,19 @@ both speed and statistical learning behavior. [`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) implements Chi-Squared feature selection. It operates on labeled data with categorical features. ChiSqSelector uses the [Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) to decide which -features to choose. It supports three selection methods: `KBest`, `Percentile` and `FPR`: +features to choose. It supports three selection methods: `numTopFeatures`, `percentile`, `fpr`: -* `KBest` chooses the `k` top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. -* `Percentile` is similar to `KBest` but chooses a fraction of all features instead of a fixed number. -* `FPR` chooses all features whose false positive rate meets some threshold. +* `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. +* `percentile` is similar to `numTopFeatures` but chooses a fraction of all features instead of a fixed number. +* `fpr` chooses all features whose p-value is below a threshold, thus controlling the false positive rate of selection. -By default, the selection method is `KBest`, the default number of top features is 50. User can use -`setNumTopFeatures`, `setPercentile` and `setAlpha` to set different selection methods. +By default, the selection method is `numTopFeatures`, with the default number of top features set to 50. +The user can choose a selection method using `setSelectorType`. The number of features to select can be tuned using a held-out validation set. ### Model Fitting -`ChiSqSelector` takes a `numTopFeatures` parameter specifying the number of top features that -the selector will select. - The [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method takes an input of `RDD[LabeledPoint]` with categorical features, learns the summary statistics, and then returns a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 7516579ec6db..b9a2110b602a 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -59,6 +59,8 @@ Spark {{site.SPARK_VERSION}} works with Java 7 and higher. If you are using Java for concisely writing functions, otherwise you can use the classes in the [org.apache.spark.api.java.function](api/java/index.html?org/apache/spark/api/java/function/package-summary.html) package. +Note that support for Java 7 is deprecated as of Spark 2.0.0 and may be removed in Spark 2.2.0. + To write a Spark application in Java, you need to add a dependency on Spark. Spark is available through Maven Central at: groupId = org.apache.spark @@ -87,6 +89,8 @@ import org.apache.spark.SparkConf Spark {{site.SPARK_VERSION}} works with Python 2.6+ or Python 3.4+. It can use the standard CPython interpreter, so C libraries like NumPy can be used. It also works with PyPy 2.3+. +Note that support for Python 2.6 is deprecated as of Spark 2.0.0, and may be removed in Spark 2.2.0. + To run Spark applications in Python, use the `bin/spark-submit` script located in the Spark directory. 
This script will load Spark's Java/Scala libraries and allow you to submit applications to a cluster. You can also use `bin/pyspark` to launch an interactive Python shell. diff --git a/docs/running-on-mesos.md index 77b06fcf3374..923d8dbebf3d 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -506,8 +506,13 @@ See the [configuration page](configuration.html) for information on Spark config since this configuration is just an upper limit and not a guaranteed amount. - - + + spark.mesos.fetcherCache.enable + false + + If set to `true`, all URIs (example: `spark.executor.uri`, `spark.mesos.uris`) will be cached by the [Mesos fetcher cache](http://mesos.apache.org/documentation/latest/fetcher/). + + # Troubleshooting and Debugging diff --git a/docs/streaming-kafka-0-10-integration.md index c1ef396907db..b645d3c3a4b5 100644 --- a/docs/streaming-kafka-0-10-integration.md +++ b/docs/streaming-kafka-0-10-integration.md @@ -17,69 +17,72 @@ For Scala/Java applications using SBT/Maven project definitions, link your strea
- import org.apache.kafka.clients.consumer.ConsumerRecord - import org.apache.kafka.common.serialization.StringDeserializer - import org.apache.spark.streaming.kafka010._ - import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent - import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe - - val kafkaParams = Map[String, Object]( - "bootstrap.servers" -> "localhost:9092,anotherhost:9092", - "key.deserializer" -> classOf[StringDeserializer], - "value.deserializer" -> classOf[StringDeserializer], - "group.id" -> "use_a_separate_group_id_for_each_stream", - "auto.offset.reset" -> "latest", - "enable.auto.commit" -> (false: java.lang.Boolean) - ) - - val topics = Array("topicA", "topicB") - val stream = KafkaUtils.createDirectStream[String, String]( - streamingContext, - PreferConsistent, - Subscribe[String, String](topics, kafkaParams) - ) - - stream.map(record => (record.key, record.value)) - +{% highlight scala %} +import org.apache.kafka.clients.consumer.ConsumerRecord +import org.apache.kafka.common.serialization.StringDeserializer +import org.apache.spark.streaming.kafka010._ +import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent +import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe + +val kafkaParams = Map[String, Object]( + "bootstrap.servers" -> "localhost:9092,anotherhost:9092", + "key.deserializer" -> classOf[StringDeserializer], + "value.deserializer" -> classOf[StringDeserializer], + "group.id" -> "use_a_separate_group_id_for_each_stream", + "auto.offset.reset" -> "latest", + "enable.auto.commit" -> (false: java.lang.Boolean) +) + +val topics = Array("topicA", "topicB") +val stream = KafkaUtils.createDirectStream[String, String]( + streamingContext, + PreferConsistent, + Subscribe[String, String](topics, kafkaParams) +) + +stream.map(record => (record.key, record.value)) +{% endhighlight %} Each item in the stream is a [ConsumerRecord](http://kafka.apache.org/0100/javadoc/org/apache/kafka/clients/consumer/ConsumerRecord.html)
- import java.util.*; - import org.apache.spark.SparkConf; - import org.apache.spark.TaskContext; - import org.apache.spark.api.java.*; - import org.apache.spark.api.java.function.*; - import org.apache.spark.streaming.api.java.*; - import org.apache.spark.streaming.kafka010.*; - import org.apache.kafka.clients.consumer.ConsumerRecord; - import org.apache.kafka.common.TopicPartition; - import org.apache.kafka.common.serialization.StringDeserializer; - import scala.Tuple2; - - Map kafkaParams = new HashMap<>(); - kafkaParams.put("bootstrap.servers", "localhost:9092,anotherhost:9092"); - kafkaParams.put("key.deserializer", StringDeserializer.class); - kafkaParams.put("value.deserializer", StringDeserializer.class); - kafkaParams.put("group.id", "use_a_separate_group_id_for_each_stream"); - kafkaParams.put("auto.offset.reset", "latest"); - kafkaParams.put("enable.auto.commit", false); - - Collection topics = Arrays.asList("topicA", "topicB"); - - final JavaInputDStream> stream = - KafkaUtils.createDirectStream( - streamingContext, - LocationStrategies.PreferConsistent(), - ConsumerStrategies.Subscribe(topics, kafkaParams) - ); - - stream.mapToPair( - new PairFunction, String, String>() { - @Override - public Tuple2 call(ConsumerRecord record) { - return new Tuple2<>(record.key(), record.value()); - } - }) +{% highlight java %} +import java.util.*; +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.*; +import org.apache.spark.streaming.api.java.*; +import org.apache.spark.streaming.kafka010.*; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.StringDeserializer; +import scala.Tuple2; + +Map kafkaParams = new HashMap<>(); +kafkaParams.put("bootstrap.servers", "localhost:9092,anotherhost:9092"); +kafkaParams.put("key.deserializer", StringDeserializer.class); +kafkaParams.put("value.deserializer", StringDeserializer.class); +kafkaParams.put("group.id", "use_a_separate_group_id_for_each_stream"); +kafkaParams.put("auto.offset.reset", "latest"); +kafkaParams.put("enable.auto.commit", false); + +Collection topics = Arrays.asList("topicA", "topicB"); + +final JavaInputDStream> stream = + KafkaUtils.createDirectStream( + streamingContext, + LocationStrategies.PreferConsistent(), + ConsumerStrategies.Subscribe(topics, kafkaParams) + ); + +stream.mapToPair( + new PairFunction, String, String>() { + @Override + public Tuple2 call(ConsumerRecord record) { + return new Tuple2<>(record.key(), record.value()); + } + }) +{% endhighlight %}
@@ -109,32 +112,35 @@ If you have a use case that is better suited to batch processing, you can create
- // Import dependencies and create kafka params as in Create Direct Stream above - - val offsetRanges = Array( - // topic, partition, inclusive starting offset, exclusive ending offset - OffsetRange("test", 0, 0, 100), - OffsetRange("test", 1, 0, 100) - ) +{% highlight scala %} +// Import dependencies and create kafka params as in Create Direct Stream above - val rdd = KafkaUtils.createRDD[String, String](sparkContext, kafkaParams, offsetRanges, PreferConsistent) +val offsetRanges = Array( + // topic, partition, inclusive starting offset, exclusive ending offset + OffsetRange("test", 0, 0, 100), + OffsetRange("test", 1, 0, 100) +) +val rdd = KafkaUtils.createRDD[String, String](sparkContext, kafkaParams, offsetRanges, PreferConsistent) +{% endhighlight %}
- // Import dependencies and create kafka params as in Create Direct Stream above
-
- OffsetRange[] offsetRanges = {
-   // topic, partition, inclusive starting offset, exclusive ending offset
-   OffsetRange.create("test", 0, 0, 100),
-   OffsetRange.create("test", 1, 0, 100)
- };
-
- JavaRDD<ConsumerRecord<String, String>> rdd = KafkaUtils.createRDD(
-   sparkContext,
-   kafkaParams,
-   offsetRanges,
-   LocationStrategies.PreferConsistent()
- );
+{% highlight java %}
+// Import dependencies and create kafka params as in Create Direct Stream above
+
+OffsetRange[] offsetRanges = {
+  // topic, partition, inclusive starting offset, exclusive ending offset
+  OffsetRange.create("test", 0, 0, 100),
+  OffsetRange.create("test", 1, 0, 100)
+};
+
+JavaRDD<ConsumerRecord<String, String>> rdd = KafkaUtils.createRDD(
+  sparkContext,
+  kafkaParams,
+  offsetRanges,
+  LocationStrategies.PreferConsistent()
+);
+{% endhighlight %}
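As a usage sketch (assuming the Scala `rdd` created above), the result is an ordinary RDD of `ConsumerRecord`s and can be processed with the usual RDD operations:

{% highlight scala %}
// e.g. pull out the decoded key/value pairs for the fetched offset ranges
val pairs = rdd.map(record => (record.key, record.value))
pairs.take(10).foreach(println)
{% endhighlight %}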
@@ -144,29 +150,33 @@ Note that you cannot use `PreferBrokers`, because without the stream there is no
- stream.foreachRDD { rdd => - val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges - rdd.foreachPartition { iter => - val o: OffsetRange = offsetRanges(TaskContext.get.partitionId) - println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") - } - } +{% highlight scala %} +stream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd.foreachPartition { iter => + val o: OffsetRange = offsetRanges(TaskContext.get.partitionId) + println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } +} +{% endhighlight %}
- stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
-   @Override
-   public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
-     final OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
-     rdd.foreachPartition(new VoidFunction<Iterator<ConsumerRecord<String, String>>>() {
-       @Override
-       public void call(Iterator<ConsumerRecord<String, String>> consumerRecords) {
-         OffsetRange o = offsetRanges[TaskContext.get().partitionId()];
-         System.out.println(
-           o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset());
-       }
-     });
-   }
- });
+{% highlight java %}
+stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
+  @Override
+  public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
+    final OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
+    rdd.foreachPartition(new VoidFunction<Iterator<ConsumerRecord<String, String>>>() {
+      @Override
+      public void call(Iterator<ConsumerRecord<String, String>> consumerRecords) {
+        OffsetRange o = offsetRanges[TaskContext.get().partitionId()];
+        System.out.println(
+          o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset());
+      }
+    });
+  }
+});
+{% endhighlight %}
@@ -183,25 +193,28 @@ Kafka has an offset commit API that stores offsets in a special Kafka topic. By
- stream.foreachRDD { rdd => - val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges - - // some time later, after outputs have completed - stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges) - } - +{% highlight scala %} +stream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + + // some time later, after outputs have completed + stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges) +} +{% endhighlight %} As with HasOffsetRanges, the cast to CanCommitOffsets will only succeed if called on the result of createDirectStream, not after transformations. The commitAsync call is threadsafe, but must occur after outputs if you want meaningful semantics.
- stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
-   @Override
-   public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
-     OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
-
-     // some time later, after outputs have completed
-     ((CanCommitOffsets) stream.inputDStream()).commitAsync(offsetRanges);
-   }
- });
+{% highlight java %}
+stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
+  @Override
+  public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
+    OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
+
+    // some time later, after outputs have completed
+    ((CanCommitOffsets) stream.inputDStream()).commitAsync(offsetRanges);
+  }
+});
+{% endhighlight %}
@@ -210,64 +223,68 @@ For data stores that support transactions, saving offsets in the same transactio
- // The details depend on your data store, but the general idea looks like this
+{% highlight scala %}
+// The details depend on your data store, but the general idea looks like this

- // begin from the the offsets committed to the database
- val fromOffsets = selectOffsetsFromYourDatabase.map { resultSet =>
-   new TopicPartition(resultSet.string("topic"), resultSet.int("partition")) -> resultSet.long("offset")
- }.toMap
+// begin from the offsets committed to the database
+val fromOffsets = selectOffsetsFromYourDatabase.map { resultSet =>
+  new TopicPartition(resultSet.string("topic"), resultSet.int("partition")) -> resultSet.long("offset")
+}.toMap

- val stream = KafkaUtils.createDirectStream[String, String](
-   streamingContext,
-   PreferConsistent,
-   Assign[String, String](fromOffsets.keys.toList, kafkaParams, fromOffsets)
- )
+val stream = KafkaUtils.createDirectStream[String, String](
+  streamingContext,
+  PreferConsistent,
+  Assign[String, String](fromOffsets.keys.toList, kafkaParams, fromOffsets)
+)

- stream.foreachRDD { rdd =>
-   val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
+stream.foreachRDD { rdd =>
+  val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges

-   val results = yourCalculation(rdd)
+  val results = yourCalculation(rdd)

-   // begin your transaction
+  // begin your transaction

-   // update results
-   // update offsets where the end of existing offsets matches the beginning of this batch of offsets
-   // assert that offsets were updated correctly
+  // update results
+  // update offsets where the end of existing offsets matches the beginning of this batch of offsets
+  // assert that offsets were updated correctly

-   // end your transaction
- }
+  // end your transaction
+}
+{% endhighlight %}
- // The details depend on your data store, but the general idea looks like this
-
- // begin from the the offsets committed to the database
- Map<TopicPartition, Long> fromOffsets = new HashMap<>();
- for (resultSet : selectOffsetsFromYourDatabase)
-   fromOffsets.put(new TopicPartition(resultSet.string("topic"), resultSet.int("partition")), resultSet.long("offset"));
- }
-
- JavaInputDStream<ConsumerRecord<String, String>> stream = KafkaUtils.createDirectStream(
-   streamingContext,
-   LocationStrategies.PreferConsistent(),
-   ConsumerStrategies.<String, String>Assign(fromOffsets.keySet(), kafkaParams, fromOffsets)
- );
-
- stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
-   @Override
-   public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
-     OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
-
-     Object results = yourCalculation(rdd);
-
-     // begin your transaction
-
-     // update results
-     // update offsets where the end of existing offsets matches the beginning of this batch of offsets
-     // assert that offsets were updated correctly
-
-     // end your transaction
-   }
- });
+{% highlight java %}
+// The details depend on your data store, but the general idea looks like this
+
+// begin from the offsets committed to the database
+Map<TopicPartition, Long> fromOffsets = new HashMap<>();
+for (resultSet : selectOffsetsFromYourDatabase) {
+  fromOffsets.put(new TopicPartition(resultSet.string("topic"), resultSet.int("partition")), resultSet.long("offset"));
+}
+
+JavaInputDStream<ConsumerRecord<String, String>> stream = KafkaUtils.createDirectStream(
+  streamingContext,
+  LocationStrategies.PreferConsistent(),
+  ConsumerStrategies.<String, String>Assign(fromOffsets.keySet(), kafkaParams, fromOffsets)
+);
+
+stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
+  @Override
+  public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
+    OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
+
+    Object results = yourCalculation(rdd);
+
+    // begin your transaction
+
+    // update results
+    // update offsets where the end of existing offsets matches the beginning of this batch of offsets
+    // assert that offsets were updated correctly
+
+    // end your transaction
+  }
+});
+{% endhighlight %}
@@ -277,25 +294,29 @@ The new Kafka consumer [supports SSL](http://kafka.apache.org/documentation.html
- val kafkaParams = Map[String, Object]( - // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS - "security.protocol" -> "SSL", - "ssl.truststore.location" -> "/some-directory/kafka.client.truststore.jks", - "ssl.truststore.password" -> "test1234", - "ssl.keystore.location" -> "/some-directory/kafka.client.keystore.jks", - "ssl.keystore.password" -> "test1234", - "ssl.key.password" -> "test1234" - ) +{% highlight scala %} +val kafkaParams = Map[String, Object]( + // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS + "security.protocol" -> "SSL", + "ssl.truststore.location" -> "/some-directory/kafka.client.truststore.jks", + "ssl.truststore.password" -> "test1234", + "ssl.keystore.location" -> "/some-directory/kafka.client.keystore.jks", + "ssl.keystore.password" -> "test1234", + "ssl.key.password" -> "test1234" +) +{% endhighlight %}
- Map kafkaParams = new HashMap(); - // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS - kafkaParams.put("security.protocol", "SSL"); - kafkaParams.put("ssl.truststore.location", "/some-directory/kafka.client.truststore.jks"); - kafkaParams.put("ssl.truststore.password", "test1234"); - kafkaParams.put("ssl.keystore.location", "/some-directory/kafka.client.keystore.jks"); - kafkaParams.put("ssl.keystore.password", "test1234"); - kafkaParams.put("ssl.key.password", "test1234"); +{% highlight java %} +Map kafkaParams = new HashMap(); +// the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS +kafkaParams.put("security.protocol", "SSL"); +kafkaParams.put("ssl.truststore.location", "/some-directory/kafka.client.truststore.jks"); +kafkaParams.put("ssl.truststore.password", "test1234"); +kafkaParams.put("ssl.keystore.location", "/some-directory/kafka.client.keystore.jks"); +kafkaParams.put("ssl.keystore.password", "test1234"); +kafkaParams.put("ssl.key.password", "test1234"); +{% endhighlight %}
diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index a6c3b3a9024d..c4c9fb3f7d3d 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -19,97 +19,103 @@ application. See the [Deploying](#deploying) subsection below.
+{% highlight scala %} - // Subscribe to 1 topic - val ds1 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1") - .load() - ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] +// Subscribe to 1 topic +val ds1 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() +ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] - // Subscribe to multiple topics - val ds2 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1,topic2") - .load() - ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] +// Subscribe to multiple topics +val ds2 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1,topic2") + .load() +ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] - // Subscribe to a pattern - val ds3 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribePattern", "topic.*") - .load() - ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] +// Subscribe to a pattern +val ds3 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribePattern", "topic.*") + .load() +ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] +{% endhighlight %}
+{% highlight java %}
- // Subscribe to 1 topic
- Dataset<Row> ds1 = spark
-   .readStream()
-   .format("kafka")
-   .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
-   .option("subscribe", "topic1")
-   .load()
- ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+// Subscribe to 1 topic
+Dataset<Row> ds1 = spark
+  .readStream()
+  .format("kafka")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("subscribe", "topic1")
+  .load();
+ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)");

- // Subscribe to multiple topics
- Dataset<Row> ds2 = spark
-   .readStream()
-   .format("kafka")
-   .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
-   .option("subscribe", "topic1,topic2")
-   .load()
- ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+// Subscribe to multiple topics
+Dataset<Row> ds2 = spark
+  .readStream()
+  .format("kafka")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("subscribe", "topic1,topic2")
+  .load();
+ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)");

- // Subscribe to a pattern
- Dataset<Row> ds3 = spark
-   .readStream()
-   .format("kafka")
-   .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
-   .option("subscribePattern", "topic.*")
-   .load()
- ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+// Subscribe to a pattern
+Dataset<Row> ds3 = spark
+  .readStream()
+  .format("kafka")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("subscribePattern", "topic.*")
+  .load();
+ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)");
+{% endhighlight %}
+{% highlight python %}
- # Subscribe to 1 topic
- ds1 = spark
-   .readStream()
-   .format("kafka")
-   .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
-   .option("subscribe", "topic1")
-   .load()
- ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+# Subscribe to 1 topic
+ds1 = spark \
+  .readStream \
+  .format("kafka") \
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \
+  .option("subscribe", "topic1") \
+  .load()
+ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")

- # Subscribe to multiple topics
- ds2 = spark
-   .readStream
-   .format("kafka")
-   .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
-   .option("subscribe", "topic1,topic2")
-   .load()
- ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+# Subscribe to multiple topics
+ds2 = spark \
+  .readStream \
+  .format("kafka") \
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \
+  .option("subscribe", "topic1,topic2") \
+  .load()
+ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")

- # Subscribe to a pattern
- ds3 = spark
-   .readStream()
-   .format("kafka")
-   .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
-   .option("subscribePattern", "topic.*")
-   .load()
- ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+# Subscribe to a pattern
+ds3 = spark \
+  .readStream \
+  .format("kafka") \
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \
+  .option("subscribePattern", "topic.*") \
+  .load()
+ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+{% endhighlight %}
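In addition to the key and value, each row produced by the kafka source carries metadata columns (topic, partition, offset, timestamp, timestampType). A brief Scala sketch, assuming the same reader options as above:

{% highlight scala %}
// Keep the Kafka metadata alongside the decoded key/value
val ds = spark
  .readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
  .option("subscribe", "topic1")
  .load()
ds.selectExpr("topic", "partition", "offset", "CAST(key AS STRING)", "CAST(value AS STRING)")
{% endhighlight %}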
diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 173fd6e8c73b..d838ed35a14f 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -14,10 +14,8 @@ Structured Streaming is a scalable and fault-tolerant stream processing engine b # Quick Example Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in -[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/ -[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/ -[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount.py). And if you -[download Spark](http://spark.apache.org/downloads.html), you can directly run the example. In any case, let’s walk through the example step-by-step and understand how it works. First, we have to import the necessary classes and create a local SparkSession, the starting point of all functionalities related to Spark. +[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount.py). +And if you [download Spark](http://spark.apache.org/downloads.html), you can directly run the example. In any case, let’s walk through the example step-by-step and understand how it works. First, we have to import the necessary classes and create a local SparkSession, the starting point of all functionalities related to Spark.
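As a minimal Scala sketch of that first step (the application name below is just a placeholder):

{% highlight scala %}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder
  .appName("StructuredNetworkWordCount")
  .getOrCreate()

import spark.implicits._
{% endhighlight %}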
@@ -409,16 +407,15 @@ Delivering end-to-end exactly-once semantics was one of key goals behind the des to track the read position in the stream. The engine uses checkpointing and write ahead logs to record the offset range of the data being processed in each trigger. The streaming sinks are designed to be idempotent for handling reprocessing. Together, using replayable sources and idempotent sinks, Structured Streaming can ensure **end-to-end exactly-once semantics** under any failure. # API using Datasets and DataFrames -Since Spark 2.0, DataFrames and Datasets can represent static, bounded data, as well as streaming, unbounded data. Similar to static Datasets/DataFrames, you can use the common entry point `SparkSession` ([Scala](api/scala/index.html#org.apache.spark.sql.SparkSession)/ -[Java](api/java/org/apache/spark/sql/SparkSession.html)/ -[Python](api/python/pyspark.sql.html#pyspark.sql.SparkSession) docs) to create streaming DataFrames/Datasets from streaming sources, and apply the same operations on them as static DataFrames/Datasets. If you are not familiar with Datasets/DataFrames, you are strongly advised to familiarize yourself with them using the +Since Spark 2.0, DataFrames and Datasets can represent static, bounded data, as well as streaming, unbounded data. Similar to static Datasets/DataFrames, you can use the common entry point `SparkSession` +([Scala](api/scala/index.html#org.apache.spark.sql.SparkSession)/[Java](api/java/org/apache/spark/sql/SparkSession.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.SparkSession) docs) +to create streaming DataFrames/Datasets from streaming sources, and apply the same operations on them as static DataFrames/Datasets. If you are not familiar with Datasets/DataFrames, you are strongly advised to familiarize yourself with them using the [DataFrame/Dataset Programming Guide](sql-programming-guide.html). ## Creating streaming DataFrames and streaming Datasets Streaming DataFrames can be created through the `DataStreamReader` interface -([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamReader)/ -[Java](api/java/org/apache/spark/sql/streaming/DataStreamReader.html)/ -[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamReader) docs) returned by `SparkSession.readStream()`. Similar to the read interface for creating static DataFrame, you can specify the details of the source – data format, schema, options, etc. +([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamReader)/[Java](api/java/org/apache/spark/sql/streaming/DataStreamReader.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamReader) docs) +returned by `SparkSession.readStream()`. Similar to the read interface for creating static DataFrame, you can specify the details of the source – data format, schema, options, etc. #### Data Sources In Spark 2.0, there are a few built-in sources. @@ -628,9 +625,7 @@ The result tables would look something like the following. ![Window Operations](img/structured-streaming-window.png) Since this windowing is similar to grouping, in code, you can use `groupBy()` and `window()` operations to express windowed aggregations. 
You can see the full code for the below examples in -[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala)/ -[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCountWindowed.java)/ -[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py). +[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCountWindowed.java)/[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py).
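For instance, a sketch of a sliding windowed count (assuming, as in the linked examples, a streaming DataFrame `words` with columns `timestamp` and `word`, and `import spark.implicits._` for the `$` column syntax):

{% highlight scala %}
import org.apache.spark.sql.functions.window

// Group the data by 10 minute windows, sliding every 5 minutes, and by word
val windowedCounts = words.groupBy(
  window($"timestamp", "10 minutes", "5 minutes"),
  $"word"
).count()
{% endhighlight %}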
@@ -753,10 +748,9 @@ In addition, there are some Dataset methods that will not work on streaming Data
 If you try any of these operations, you will see an AnalysisException like "operation XYZ is not supported with streaming DataFrames/Datasets".

 ## Starting Streaming Queries
-Once you have defined the final result DataFrame/Dataset, all that is left is for you start the streaming computation. To do that, you have to use the
-`DataStreamWriter` ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamWriter)/
-[Java](api/java/org/apache/spark/sql/streaming/DataStreamWriter.html)/
-[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamWriter) docs) returned through `Dataset.writeStream()`. You will have to specify one or more of the following in this interface.
+Once you have defined the final result DataFrame/Dataset, all that is left is for you to start the streaming computation. To do that, you have to use the `DataStreamWriter`
+([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamWriter)/[Java](api/java/org/apache/spark/sql/streaming/DataStreamWriter.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamWriter) docs)
+returned through `Dataset.writeStream()`. You will have to specify one or more of the following in this interface.

 - *Details of the output sink:* Data format, location, etc.

@@ -953,8 +947,9 @@ spark.sql("select * from aggregates").show() # interactively query in-memory t
#### Using Foreach -The `foreach` operation allows arbitrary operations to be computed on the output data. As of Spark 2.0, this is available only for Scala and Java. To use this, you will have to implement the interface `ForeachWriter` ([Scala](api/scala/index.html#org.apache.spark.sql.ForeachWriter)/ -[Java](api/java/org/apache/spark/sql/ForeachWriter.html) docs), which has methods that get called whenever there is a sequence of rows generated as output after a trigger. Note the following important points. +The `foreach` operation allows arbitrary operations to be computed on the output data. As of Spark 2.0, this is available only for Scala and Java. To use this, you will have to implement the interface `ForeachWriter` +([Scala](api/scala/index.html#org.apache.spark.sql.ForeachWriter)/[Java](api/java/org/apache/spark/sql/ForeachWriter.html) docs), +which has methods that get called whenever there is a sequence of rows generated as output after a trigger. Note the following important points. - The writer must be serializable, as it will be serialized and sent to the executors for execution. @@ -1046,9 +1041,9 @@ query.sinkStatus() # progress information about data written to the output sin
-You can start any number of queries in a single SparkSession. They will all be running concurrently sharing the cluster resources. You can use `sparkSession.streams()` to get the `StreamingQueryManager` ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryManager)/ -[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryManager.html)/ -[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.StreamingQueryManager) docs) that can be used to manage the currently active queries. +You can start any number of queries in a single SparkSession. They will all be running concurrently sharing the cluster resources. You can use `sparkSession.streams()` to get the `StreamingQueryManager` +([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryManager)/[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryManager.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.StreamingQueryManager) docs) +that can be used to manage the currently active queries.
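A short Scala sketch of what that looks like (the `...` and `id` below are placeholders):

{% highlight scala %}
val spark: SparkSession = ...  // an existing session with active queries

spark.streams.active                  // list of currently active streaming queries
spark.streams.get(id)                 // get a query object by its unique id
spark.streams.awaitAnyTermination()   // block until any one of them terminates
{% endhighlight %}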
@@ -1092,8 +1087,9 @@ spark.streams().awaitAnyTermination() # block until any one of them terminates
-Finally, for asynchronous monitoring of streaming queries, you can create and attach a `StreamingQueryListener` ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryListener)/ -[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryListener.html) docs), which will give you regular callback-based updates when queries are started and terminated. +Finally, for asynchronous monitoring of streaming queries, you can create and attach a `StreamingQueryListener` +([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryListener)/[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryListener.html) docs), +which will give you regular callback-based updates when queries are started and terminated. ## Recovering from Failures with Checkpointing In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. As of Spark 2.0, this checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries). diff --git a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java index 7df145e3117b..89855e81f1f7 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java @@ -54,7 +54,7 @@ public static void main(String[] args) throws Exception { public Integer call(Integer integer) { double x = Math.random() * 2 - 1; double y = Math.random() * 2 - 1; - return (x * x + y * y < 1) ? 1 : 0; + return (x * x + y * y <= 1) ? 1 : 0; } }).reduce(new Function2() { @Override diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java new file mode 100644 index 000000000000..4213c05703cc --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.ml.feature.Interaction; +import org.apache.spark.ml.feature.VectorAssembler; +import org.apache.spark.ml.linalg.Vectors; +import org.apache.spark.sql.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import java.util.Arrays; +import java.util.List; + +// $example on$ +// $example off$ + +public class JavaInteractionExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaInteractionExample") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(1, 1, 2, 3, 8, 4, 5), + RowFactory.create(2, 4, 3, 8, 7, 9, 8), + RowFactory.create(3, 6, 1, 9, 2, 3, 6), + RowFactory.create(4, 10, 8, 6, 9, 4, 5), + RowFactory.create(5, 9, 2, 7, 10, 7, 3), + RowFactory.create(6, 1, 1, 4, 2, 8, 4) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("id1", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("id2", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("id3", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("id4", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("id5", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("id6", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("id7", DataTypes.IntegerType, false, Metadata.empty()) + }); + + Dataset df = spark.createDataFrame(data, schema); + + VectorAssembler assembler1 = new VectorAssembler() + .setInputCols(new String[]{"id2", "id3", "id4"}) + .setOutputCol("vec1"); + + Dataset assembled1 = assembler1.transform(df); + + VectorAssembler assembler2 = new VectorAssembler() + .setInputCols(new String[]{"id5", "id6", "id7"}) + .setOutputCol("vec2"); + + Dataset assembled2 = assembler2.transform(assembled1).select("id1", "vec1", "vec2"); + + Interaction interaction = new Interaction() + .setInputCols(new String[]{"id1","vec1","vec2"}) + .setOutputCol("interactedCol"); + + Dataset interacted = interaction.transform(assembled2); + + interacted.show(false); + // $example off$ + + spark.stop(); + } +} + diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py index e3f0c4aeef1b..37029b76798f 100755 --- a/examples/src/main/python/pi.py +++ b/examples/src/main/python/pi.py @@ -38,7 +38,7 @@ def f(_): x = random() * 2 - 1 y = random() * 2 - 1 - return 1 if x ** 2 + y ** 2 < 1 else 0 + return 1 if x ** 2 + y ** 2 <= 1 else 0 count = spark.sparkContext.parallelize(range(1, n + 1), partitions).map(f).reduce(add) print("Pi is roughly %f" % (4.0 * count / n)) diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala index 720d92fb9d02..121b768e4198 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala @@ -26,7 +26,7 @@ object LocalPi { for (i <- 1 to 100000) { val x = random * 2 - 1 val y = random * 2 - 1 - if (x*x + y*y < 1) count += 1 + if (x*x + y*y <= 1) count += 1 } println("Pi is roughly " + 4 * count / 100000.0) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index 272c1a4fc2f4..a5cacf17a5cc 100644 --- 
a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -34,7 +34,7 @@ object SparkPi { val count = spark.sparkContext.parallelize(1 until n, slices).map { i => val x = random * 2 - 1 val y = random * 2 - 1 - if (x*x + y*y < 1) 1 else 0 + if (x*x + y*y <= 1) 1 else 0 }.reduce(_ + _) println("Pi is roughly " + 4.0 * count / (n - 1)) spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/InteractionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/InteractionExample.scala new file mode 100644 index 000000000000..8113c992b1d6 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/InteractionExample.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Interaction +import org.apache.spark.ml.feature.VectorAssembler +// $example off$ +import org.apache.spark.sql.SparkSession + +object InteractionExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("InteractionExample") + .getOrCreate() + + // $example on$ + val df = spark.createDataFrame(Seq( + (1, 1, 2, 3, 8, 4, 5), + (2, 4, 3, 8, 7, 9, 8), + (3, 6, 1, 9, 2, 3, 6), + (4, 10, 8, 6, 9, 4, 5), + (5, 9, 2, 7, 10, 7, 3), + (6, 1, 1, 4, 2, 8, 4) + )).toDF("id1", "id2", "id3", "id4", "id5", "id6", "id7") + + val assembler1 = new VectorAssembler(). + setInputCols(Array("id2", "id3", "id4")). + setOutputCol("vec1") + + val assembled1 = assembler1.transform(df) + + val assembler2 = new VectorAssembler(). + setInputCols(Array("id5", "id6", "id7")). 
+ setOutputCol("vec2") + + val assembled2 = assembler2.transform(assembled1).select("id1", "vec1", "vec2") + + val interaction = new Interaction() + .setInputCols(Array("id1", "vec1", "vec2")) + .setOutputCol("interactedCol") + + val interacted = interaction.transform(assembled2) + + interacted.show(truncate = false) + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index cfedb5a042a3..36ba0bda528d 100644 --- a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -29,7 +29,7 @@ public abstract class LocalJavaStreamingContext { @Before public void setUp() { SparkConf conf = new SparkConf() - .setMaster("local[2]") + .setMaster("local[4]") .setAppName("test") .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 1c93079497f6..cd6e4b78ec38 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -43,7 +43,7 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfterAll with @transient private var _sc: SparkContext = _ val conf = new SparkConf() - .setMaster("local[2]") + .setMaster("local[4]") .setAppName(this.getClass.getSimpleName) .set("spark.streaming.clock", "org.apache.spark.util.ManualClock") diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala index 40d568a12c25..13d717092a89 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.kafka010 -import java.io.Writer - import scala.collection.mutable.HashMap import scala.util.control.NonFatal diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 61cba737d148..5bcc5124b091 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} +import java.io._ +import java.nio.charset.StandardCharsets import scala.collection.JavaConverters._ import scala.util.control.NonFatal @@ -88,7 +90,10 @@ private[kafka010] case class KafkaSource( private val sc = sqlContext.sparkContext - private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong + private val pollTimeoutMs = sourceOptions.getOrElse( + "kafkaConsumer.pollTimeoutMs", + sc.conf.getTimeAsMs("spark.network.timeout", "120s").toString + ).toLong private val maxOffsetFetchAttempts = sourceOptions.getOrElse("fetchOffset.numRetries", "3").toInt 
@@ -111,7 +116,22 @@ private[kafka010] case class KafkaSource( * `KafkaConsumer.poll` may hang forever (KAFKA-1894). */ private lazy val initialPartitionOffsets = { - val metadataLog = new HDFSMetadataLog[KafkaSourceOffset](sqlContext.sparkSession, metadataPath) + val metadataLog = + new HDFSMetadataLog[KafkaSourceOffset](sqlContext.sparkSession, metadataPath) { + override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = { + val bytes = metadata.json.getBytes(StandardCharsets.UTF_8) + out.write(bytes.length) + out.write(bytes) + } + + override def deserialize(in: InputStream): KafkaSourceOffset = { + val length = in.read() + val bytes = new Array[Byte](length) + in.read(bytes) + KafkaSourceOffset(SerializedOffset(new String(bytes, StandardCharsets.UTF_8))) + } + } + metadataLog.get(0).getOrElse { val offsets = startingOffsets match { case EarliestOffsets => KafkaSourceOffset(fetchEarliestOffsets()) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index b5ade982515f..b5da415b3097 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition -import org.apache.spark.sql.execution.streaming.Offset +import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and @@ -27,9 +27,8 @@ import org.apache.spark.sql.execution.streaming.Offset */ private[kafka010] case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { - override def toString(): String = { - partitionToOffsets.toSeq.sortBy(_._1.toString).mkString("[", ", ", "]") - } + + override val json = JsonUtils.partitionOffsets(partitionToOffsets) } /** Companion object of the [[KafkaSourceOffset]] */ @@ -38,6 +37,7 @@ private[kafka010] object KafkaSourceOffset { def getPartitionOffsets(offset: Offset): Map[TopicPartition, Long] = { offset match { case o: KafkaSourceOffset => o.partitionToOffsets + case so: SerializedOffset => KafkaSourceOffset(so).partitionToOffsets case _ => throw new IllegalArgumentException( s"Invalid conversion from offset of ${offset.getClass} to KafkaSourceOffset") @@ -51,4 +51,10 @@ private[kafka010] object KafkaSourceOffset { def apply(offsetTuples: (String, Int, Long)*): KafkaSourceOffset = { KafkaSourceOffset(offsetTuples.map { case(t, p, o) => (new TopicPartition(t, p), o) }.toMap) } + + /** + * Returns [[KafkaSourceOffset]] from a JSON [[SerializedOffset]] + */ + def apply(offset: SerializedOffset): KafkaSourceOffset = + KafkaSourceOffset(JsonUtils.partitionOffsets(offset.json)) } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala index 7056a41b1751..881018fd9566 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala @@ -17,9 +17,13 @@ package org.apache.spark.sql.kafka010 +import java.io.File + +import org.apache.spark.sql.execution.streaming._ 
import org.apache.spark.sql.streaming.OffsetSuite +import org.apache.spark.sql.test.SharedSQLContext -class KafkaSourceOffsetSuite extends OffsetSuite { +class KafkaSourceOffsetSuite extends OffsetSuite with SharedSQLContext { compare( one = KafkaSourceOffset(("t", 0, 1L)), @@ -36,4 +40,53 @@ class KafkaSourceOffsetSuite extends OffsetSuite { compare( one = KafkaSourceOffset(("t", 0, 1L)), two = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 1L))) + + + val kso1 = KafkaSourceOffset(("t", 0, 1L)) + val kso2 = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 3L)) + val kso3 = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 3L), ("t", 1, 4L)) + + compare(KafkaSourceOffset(SerializedOffset(kso1.json)), + KafkaSourceOffset(SerializedOffset(kso2.json))) + + test("basic serialization - deserialization") { + assert(KafkaSourceOffset.getPartitionOffsets(kso1) == + KafkaSourceOffset.getPartitionOffsets(SerializedOffset(kso1.json))) + } + + + testWithUninterruptibleThread("OffsetSeqLog serialization - deserialization") { + withTempDir { temp => + // use non-existent directory to test whether log make the dir + val dir = new File(temp, "dir") + val metadataLog = new OffsetSeqLog(spark, dir.getAbsolutePath) + val batch0 = OffsetSeq.fill(kso1) + val batch1 = OffsetSeq.fill(kso2, kso3) + + val batch0Serialized = OffsetSeq.fill(batch0.offsets.flatMap(_.map(o => + SerializedOffset(o.json))): _*) + + val batch1Serialized = OffsetSeq.fill(batch1.offsets.flatMap(_.map(o => + SerializedOffset(o.json))): _*) + + assert(metadataLog.add(0, batch0)) + assert(metadataLog.getLatest() === Some(0 -> batch0Serialized)) + assert(metadataLog.get(0) === Some(batch0Serialized)) + + assert(metadataLog.add(1, batch1)) + assert(metadataLog.get(0) === Some(batch0Serialized)) + assert(metadataLog.get(1) === Some(batch1Serialized)) + assert(metadataLog.getLatest() === Some(1 -> batch1Serialized)) + assert(metadataLog.get(None, Some(1)) === + Array(0 -> batch0Serialized, 1 -> batch1Serialized)) + + // Adding the same batch does nothing + metadataLog.add(1, OffsetSeq.fill(LongOffset(3))) + assert(metadataLog.get(0) === Some(batch0Serialized)) + assert(metadataLog.get(1) === Some(batch1Serialized)) + assert(metadataLog.getLatest() === Some(1 -> batch1Serialized)) + assert(metadataLog.get(None, Some(1)) === + Array(0 -> batch0Serialized, 1 -> batch1Serialized)) + } + } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index ed4cc75920e8..89e713f92df4 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -306,6 +306,30 @@ class KafkaSourceSuite extends KafkaSourceTest { ) } + test("starting offset is latest by default") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("0")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + + val kafka = reader.load() + .selectExpr("CAST(value AS STRING)") + .as[String] + val mapped = kafka.map(_.toInt) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(1, 2, 3) // should not have 0 + ) + } + test("bad source options") { def testBadOptions(options: 
(String, String)*)(expectedMsgs: String*): Unit = { val ex = intercept[IllegalArgumentException] { diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index 5b5a9ac48c7c..98394251bb23 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -66,7 +66,8 @@ private[spark] class KafkaRDD[K, V]( " must be set to false for executor kafka params, else offsets may commit before processing") // TODO is it necessary to have separate configs for initial poll time vs ongoing poll time? - private val pollTimeout = conf.getLong("spark.streaming.kafka.consumer.poll.ms", 512) + private val pollTimeout = conf.getLong("spark.streaming.kafka.consumer.poll.ms", + conf.getTimeAsMs("spark.network.timeout", "120s")) private val cacheInitialCapacity = conf.getInt("spark.streaming.kafka.consumer.cache.initialCapacity", 16) private val cacheMaxCapacity = diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index 73b6ca384438..7d6693b4cdf5 100644 --- a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -20,6 +20,7 @@ package org.apache.spark.deploy.mesos import java.util.concurrent.CountDownLatch import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.deploy.mesos.config._ import org.apache.spark.deploy.mesos.ui.MesosClusterUI import org.apache.spark.deploy.rest.mesos.MesosRestServer import org.apache.spark.internal.Logging @@ -51,7 +52,7 @@ private[mesos] class MesosClusterDispatcher( extends Logging { private val publicAddress = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(args.host) - private val recoveryMode = conf.get("spark.deploy.recoveryMode", "NONE").toUpperCase() + private val recoveryMode = conf.get(RECOVERY_MODE).toUpperCase() logInfo("Recovery mode in Mesos dispatcher set to: " + recoveryMode) private val engineFactory = recoveryMode match { @@ -74,7 +75,7 @@ private[mesos] class MesosClusterDispatcher( def start(): Unit = { webUi.bind() - scheduler.frameworkUrl = conf.get("spark.mesos.dispatcher.webui.url", webUi.activeWebUiUrl) + scheduler.frameworkUrl = conf.get(DISPATCHER_WEBUI_URL).getOrElse(webUi.activeWebUiUrl) scheduler.start() server.start() } @@ -99,8 +100,8 @@ private[mesos] object MesosClusterDispatcher extends Logging { conf.setMaster(dispatcherArgs.masterUrl) conf.setAppName(dispatcherArgs.name) dispatcherArgs.zookeeperUrl.foreach { z => - conf.set("spark.deploy.recoveryMode", "ZOOKEEPER") - conf.set("spark.deploy.zookeeper.url", z) + conf.set(RECOVERY_MODE, "ZOOKEEPER") + conf.set(ZOOKEEPER_URL, z) } val dispatcher = new MesosClusterDispatcher(dispatcherArgs, conf) dispatcher.start() diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 6b297c4600a6..859aa836a315 100644 --- a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import 
org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.ExternalShuffleService +import org.apache.spark.deploy.mesos.config._ import org.apache.spark.internal.Logging import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler @@ -114,7 +115,7 @@ private[mesos] class MesosExternalShuffleService(conf: SparkConf, securityManage protected override def newShuffleBlockHandler( conf: TransportConf): ExternalShuffleBlockHandler = { - val cleanerIntervalS = this.conf.getTimeAsSeconds("spark.shuffle.cleaner.interval", "30s") + val cleanerIntervalS = this.conf.get(SHUFFLE_CLEANER_INTERVAL_S) new MesosExternalShuffleBlockHandler(conf, cleanerIntervalS) } } diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala b/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala new file mode 100644 index 000000000000..19e253394f1b --- /dev/null +++ b/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import java.util.concurrent.TimeUnit + +import org.apache.spark.internal.config.ConfigBuilder + +package object config { + + /* Common app configuration. */ + + private[spark] val SHUFFLE_CLEANER_INTERVAL_S = + ConfigBuilder("spark.shuffle.cleaner.interval") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("30s") + + private[spark] val RECOVERY_MODE = + ConfigBuilder("spark.deploy.recoveryMode") + .stringConf + .createWithDefault("NONE") + + private[spark] val DISPATCHER_WEBUI_URL = + ConfigBuilder("spark.mesos.dispatcher.webui.url") + .doc("Set the Spark Mesos dispatcher webui_url for interacting with the " + + "framework. If unset it will point to Spark's internal web UI.") + .stringConf + .createOptional + + private[spark] val ZOOKEEPER_URL = + ConfigBuilder("spark.deploy.zookeeper.url") + .doc("When `spark.deploy.recoveryMode` is set to ZOOKEEPER, this " + + "configuration is used to set the zookeeper URL to connect to.") + .stringConf + .createOptional + + private[spark] val HISTORY_SERVER_URL = + ConfigBuilder("spark.mesos.dispatcher.historyServer.url") + .doc("Set the URL of the history server. 
The dispatcher will then " + + "link each driver to its entry in the history server.") + .stringConf + .createOptional + +} diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala index 8dcbdaad8685..13ba7d311e57 100644 --- a/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala +++ b/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -23,12 +23,13 @@ import scala.xml.Node import org.apache.mesos.Protos.TaskStatus +import org.apache.spark.deploy.mesos.config._ import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.scheduler.cluster.mesos.MesosClusterSubmissionState import org.apache.spark.ui.{UIUtils, WebUIPage} private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage("") { - private val historyServerURL = parent.conf.getOption("spark.mesos.dispatcher.historyServer.url") + private val historyServerURL = parent.conf.get(HISTORY_SERVER_URL) def render(request: HttpServletRequest): Seq[Node] = { val state = parent.scheduler.getSchedulerState() diff --git a/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 3b96488a129a..ff60b88c6d53 100644 --- a/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.rest.mesos import java.io.File import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import java.util.concurrent.atomic.AtomicLong import javax.servlet.http.HttpServletResponse @@ -62,11 +62,10 @@ private[mesos] class MesosSubmitRequestServlet( private val DEFAULT_CORES = 1.0 private val nextDriverNumber = new AtomicLong(0) - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - private def newDriverId(submitDate: Date): String = { - "driver-%s-%04d".format( - createDateFormat.format(submitDate), nextDriverNumber.incrementAndGet()) - } + // For application IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) + private def newDriverId(submitDate: Date): String = + f"driver-${createDateFormat.format(submitDate)}-${nextDriverNumber.incrementAndGet()}%04d" /** * Build a driver description from the fields specified in the submit request. 
diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 0b454997772d..8db1d126d59b 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -129,6 +129,7 @@ private[spark] class MesosClusterScheduler( private val queuedCapacity = conf.getInt("spark.mesos.maxDrivers", 200) private val retainedDrivers = conf.getInt("spark.mesos.retainedDrivers", 200) private val maxRetryWaitTime = conf.getInt("spark.mesos.cluster.retry.wait.max", 60) // 1 minute + private val useFetchCache = conf.getBoolean("spark.mesos.fetchCache.enable", false) private val schedulerState = engineFactory.createEngine("scheduler") private val stateLock = new Object() private val finishedDrivers = @@ -396,7 +397,7 @@ private[spark] class MesosClusterScheduler( val jarUrl = desc.jarUrl.stripPrefix("file:").stripPrefix("local:") ((jarUrl :: confUris) ++ getDriverExecutorURI(desc).toList).map(uri => - CommandInfo.URI.newBuilder().setValue(uri.trim()).build()) + CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) } private def getDriverCommandValue(desc: MesosDriverDescription): String = { @@ -481,7 +482,7 @@ private[spark] class MesosClusterScheduler( .filter { case (key, _) => !replicatedOptionsBlacklist.contains(key) } .toMap (defaultConf ++ driverConf).foreach { case (key, value) => - options ++= Seq("--conf", s"$key=${shellEscape(value)}") } + options ++= Seq("--conf", s""""$key=${shellEscape(value)}"""".stripMargin) } options } diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index e67bf3e328f9..842c05e7bf73 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -59,6 +59,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + val useFetcherCache = conf.getBoolean("spark.mesos.fetcherCache.enable", false) + val maxGpus = conf.getInt("spark.mesos.gpus.max", 0) private[this] val shutdownTimeoutMS = @@ -156,7 +158,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( sc.sparkUser, sc.appName, sc.conf, - sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.appUIAddress)), + sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.webUrl)), None, None, sc.conf.getOption("spark.mesos.driver.frameworkId") @@ -226,10 +228,10 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") - command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) + command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get).setCache(useFetcherCache)) } - conf.getOption("spark.mesos.uris").foreach(setupUris(_, command)) + conf.getOption("spark.mesos.uris").foreach(setupUris(_, command, useFetcherCache)) command.build() } diff --git 
a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 09a252f3c74a..c1aa00151e69 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -77,7 +77,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( sc.sparkUser, sc.appName, sc.conf, - sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.appUIAddress)), + sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.webUrl)), Option.empty, Option.empty, sc.conf.getOption("spark.mesos.driver.frameworkId") diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 73cc241239c4..9cb60237044a 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -369,9 +369,11 @@ trait MesosSchedulerUtils extends Logging { sc.executorMemory } - def setupUris(uris: String, builder: CommandInfo.Builder): Unit = { + def setupUris(uris: String, + builder: CommandInfo.Builder, + useFetcherCache: Boolean = false): Unit = { uris.split(",").foreach { uri => - builder.addUris(CommandInfo.URI.newBuilder().setValue(uri.trim())) + builder.addUris(CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetcherCache)) } } diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index 75ba02e470e2..f73638fda623 100644 --- a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -463,6 +463,34 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(launchedTasks.head.getCommand.getUrisList.asScala(0).getValue == url) } + test("mesos supports setting fetcher cache") { + val url = "spark.spark.spark.com" + setBackend(Map( + "spark.mesos.fetcherCache.enable" -> "true", + "spark.executor.uri" -> url + ), false) + val offers = List(Resources(backend.executorMemory(sc), 1)) + offerResources(offers) + val launchedTasks = verifyTaskLaunched(driver, "o1") + val uris = launchedTasks.head.getCommand.getUrisList + assert(uris.size() == 1) + assert(uris.asScala.head.getCache) + } + + test("mesos supports disabling fetcher cache") { + val url = "spark.spark.spark.com" + setBackend(Map( + "spark.mesos.fetcherCache.enable" -> "false", + "spark.executor.uri" -> url + ), false) + val offers = List(Resources(backend.executorMemory(sc), 1)) + offerResources(offers) + val launchedTasks = verifyTaskLaunched(driver, "o1") + val uris = launchedTasks.head.getCommand.getUrisList + assert(uris.size() == 1) + assert(!uris.asScala.head.getCache) + } + private case class Resources(mem: Int, cpus: Int, gpus: Int = 0) private def verifyDeclinedOffer(driver: SchedulerDriver, diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 
195a93e08672..f406f8c426d0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -169,7 +169,7 @@ class Pipeline @Since("1.4.0") ( override def copy(extra: ParamMap): Pipeline = { val map = extractParamMap(extra) val newStages = map(stages).map(_.copy(extra)) - new Pipeline().setStages(newStages) + new Pipeline(uid).setStages(newStages) } @Since("1.2.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index e29d7f48a1d6..aa92edde7acd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -58,7 +58,8 @@ private[ml] trait PredictorParams extends Params /** * :: DeveloperApi :: - * Abstraction for prediction problems (regression and classification). + * Abstraction for prediction problems (regression and classification). It accepts all NumericType + * labels and will automatically cast it to DoubleType in [[fit()]]. * * @tparam FeaturesType Type of features. * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. @@ -87,7 +88,12 @@ abstract class Predictor[ // This handles a few items such as schema validation. // Developers only need to implement train(). transformSchema(dataset.schema, logging = true) - copyValues(train(dataset).setParent(this)) + + // Cast LabelCol to DoubleType and keep the metadata. + val labelMeta = dataset.schema($(labelCol)).metadata + val casted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta) + + copyValues(train(casted).setParent(this)) } override def copy(extra: ParamMap): Learner @@ -121,7 +127,7 @@ abstract class Predictor[ * and put it in an RDD with strong types. */ protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = { - dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index d1b21b16f234..a3da3067e1b5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -71,7 +71,7 @@ abstract class Classifier[ * and put it in an RDD with strong types. * * @param dataset DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]]) - * and features ([[Vector]]). Labels are cast to [[DoubleType]]. + * and features ([[Vector]]). * @param numClasses Number of classes label can take. Labels must be integers in the range * [0, numClasses). * @throws SparkException if any label is not an integer >= 0 @@ -79,7 +79,7 @@ abstract class Classifier[ protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = { require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" + s" $numClasses, but requires numClasses > 0.") - dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" + s" dataset with invalid label $label. 
Labels must be integers in range" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 8bffe0cda032..f8f164e8c14b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -128,7 +128,7 @@ class GBTClassifier @Since("1.4.0") ( // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports // 2 classes now. This lets us provide a more precise error message. val oldDataset: RDD[LabeledPoint] = - dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => require(label == 0 || label == 1, s"GBTClassifier was given" + s" dataset with invalid label $label. Labels must be in {0,1}; note that" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 8fdaae04c42e..c4651054fd76 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -322,7 +322,7 @@ class LogisticRegression @Since("1.2.0") ( LogisticRegressionModel = { val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = - dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 994ed993c99d..b03a07a6bc1e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -171,7 +171,7 @@ class NaiveBayes @Since("1.5.0") ( // Aggregates term frequencies per label. // TODO: Calling aggregateByKey and collect creates two stages, we can implement something // TODO: similar to reduceByKeyLocally to save one stage. 
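Taken together, the Predictor and Classifier changes above centralize label handling: `fit()` now casts any NumericType label column to DoubleType (preserving its metadata), so the per-algorithm `select` calls no longer need their own cast. A hedged usage sketch with an IntegerType label (the data and app name are made up):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

// Sketch: integer labels are accepted directly; fit() casts the label
// column to DoubleType before the per-algorithm train() runs.
object NumericLabelSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("numeric-label").getOrCreate()
    import spark.implicits._

    val df = Seq(
      (0, Vectors.dense(0.0, 1.0)),
      (1, Vectors.dense(1.0, 0.0))
    ).toDF("label", "features")          // "label" is IntegerType here

    val model = new LogisticRegression().setMaxIter(5).fit(df)
    model.transform(df).select("label", "prediction").show()
    spark.stop()
  }
}
```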
- val aggregated = dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd + val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2))) }.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))( seqOp = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 2718dd93dcb5..f8a606d60b2a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -94,8 +94,9 @@ class BisectingKMeansModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): BisectingKMeansModel = { - val copied = new BisectingKMeansModel(uid, parentModel) - copyValues(copied, extra) + val copied = copyValues(new BisectingKMeansModel(uid, parentModel), extra) + if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) + copied.setParent(this.parent) } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 8fac63fefbb5..a0bd66e731a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -89,8 +89,9 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): GaussianMixtureModel = { - val copied = new GaussianMixtureModel(uid, weights, gaussians) - copyValues(copied, extra).setParent(this.parent) + val copied = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra) + if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) + copied.setParent(this.parent) } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 85bb8c93b3fa..a0d481b294ac 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -108,8 +108,9 @@ class KMeansModel private[ml] ( @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { - val copied = new KMeansModel(uid, parentModel) - copyValues(copied, extra) + val copied = copyValues(new KMeansModel(uid, parentModel), extra) + if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) + copied.setParent(this.parent) } /** @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index d0385e220e1e..653fa41124f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -42,69 +42,80 @@ private[feature] trait ChiSqSelectorParams extends Params with HasFeaturesCol with HasOutputCol with HasLabelCol { /** - * Number of features that selector will select (ordered by statistic value descending). If the + * Number of features that selector will select, ordered by ascending p-value. If the * number of features is less than numTopFeatures, then this will select all features. - * Only applicable when selectorType = "kbest". + * Only applicable when selectorType = "numTopFeatures". 
* The default value of numTopFeatures is 50. * * @group param */ + @Since("1.6.0") final val numTopFeatures = new IntParam(this, "numTopFeatures", - "Number of features that selector will select, ordered by statistics value descending. If the" + + "Number of features that selector will select, ordered by ascending p-value. If the" + " number of features is < numTopFeatures, then this will select all features.", ParamValidators.gtEq(1)) setDefault(numTopFeatures -> 50) /** @group getParam */ + @Since("1.6.0") def getNumTopFeatures: Int = $(numTopFeatures) /** * Percentile of features that selector will select, ordered by statistics value descending. * Only applicable when selectorType = "percentile". * Default value is 0.1. + * @group param */ + @Since("2.1.0") final val percentile = new DoubleParam(this, "percentile", - "Percentile of features that selector will select, ordered by statistics value descending.", + "Percentile of features that selector will select, ordered by ascending p-value.", ParamValidators.inRange(0, 1)) setDefault(percentile -> 0.1) /** @group getParam */ + @Since("2.1.0") def getPercentile: Double = $(percentile) /** * The highest p-value for features to be kept. * Only applicable when selectorType = "fpr". * Default value is 0.05. + * @group param */ - final val alpha = new DoubleParam(this, "alpha", "The highest p-value for features to be kept.", + final val fpr = new DoubleParam(this, "fpr", "The highest p-value for features to be kept.", ParamValidators.inRange(0, 1)) - setDefault(alpha -> 0.05) + setDefault(fpr -> 0.05) /** @group getParam */ - def getAlpha: Double = $(alpha) + def getFpr: Double = $(fpr) /** * The selector type of the ChisqSelector. - * Supported options: "kbest" (default), "percentile" and "fpr". + * Supported options: "numTopFeatures" (default), "percentile", "fpr". + * @group param */ + @Since("2.1.0") final val selectorType = new Param[String](this, "selectorType", "The selector type of the ChisqSelector. " + - "Supported options: kbest (default), percentile and fpr.", - ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes.toArray)) - setDefault(selectorType -> OldChiSqSelector.KBest) + "Supported options: " + OldChiSqSelector.supportedSelectorTypes.mkString(", "), + ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes)) + setDefault(selectorType -> OldChiSqSelector.NumTopFeatures) /** @group getParam */ + @Since("2.1.0") def getSelectorType: String = $(selectorType) } /** * Chi-Squared feature selection, which selects categorical features to use for predicting a * categorical label. - * The selector supports three selection methods: `kbest`, `percentile` and `fpr`. - * `kbest` chooses the `k` top features according to a chi-squared test. - * `percentile` is similar but chooses a fraction of all features instead of a fixed number. - * `fpr` chooses all features whose false positive rate meets some threshold. - * By default, the selection method is `kbest`, the default number of top features is 50. + * The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`. + * - `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. + * - `percentile` is similar but chooses a fraction of all features instead of a fixed number. + * - `fpr` chooses all features whose p-value is below a threshold, thus controlling the false + * positive rate of selection. 
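A hedged sketch of the renamed `spark.ml` selector API after this change (thresholds and column names are illustrative); it shows the default `numTopFeatures` mode and the p-value based `fpr` mode:

```scala
import org.apache.spark.ml.feature.ChiSqSelector

// numTopFeatures (the default, formerly "kbest"): keep the k best features.
val topK = new ChiSqSelector()
  .setSelectorType("numTopFeatures")
  .setNumTopFeatures(20)
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setOutputCol("selectedFeatures")

// fpr: keep every feature whose chi-squared p-value is below the threshold
// (setAlpha was renamed to setFpr in this change).
val byFpr = new ChiSqSelector()
  .setSelectorType("fpr")
  .setFpr(0.05)
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setOutputCol("selectedFeatures")

// val model = topK.fit(trainingDF)   // trainingDF is assumed to exist
```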
+ * By default, the selection method is `numTopFeatures`, with the default number of top features + * set to 50. */ @Since("1.6.0") final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String) @@ -113,10 +124,6 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str @Since("1.6.0") def this() = this(Identifiable.randomUID("chiSqSelector")) - /** @group setParam */ - @Since("2.1.0") - def setSelectorType(value: String): this.type = set(selectorType, value) - /** @group setParam */ @Since("1.6.0") def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value) @@ -127,7 +134,11 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str /** @group setParam */ @Since("2.1.0") - def setAlpha(value: Double): this.type = set(alpha, value) + def setFpr(value: Double): this.type = set(fpr, value) + + /** @group setParam */ + @Since("2.1.0") + def setSelectorType(value: String): this.type = set(selectorType, value) /** @group setParam */ @Since("1.6.0") @@ -153,15 +164,15 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str .setSelectorType($(selectorType)) .setNumTopFeatures($(numTopFeatures)) .setPercentile($(percentile)) - .setAlpha($(alpha)) + .setFpr($(fpr)) val model = selector.fit(input) copyValues(new ChiSqSelectorModel(uid, model).setParent(this)) } @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { - val otherPairs = OldChiSqSelector.supportedTypeAndParamPairs.filter(_._1 != $(selectorType)) - otherPairs.foreach { case (_, paramName: String) => + val otherPairs = OldChiSqSelector.supportedSelectorTypes.filter(_ != $(selectorType)) + otherPairs.foreach { paramName: String => if (isSet(getParam(paramName))) { logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala index 2f5299b01022..96fd0d18b5ae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala @@ -16,9 +16,10 @@ */ package org.apache.spark.ml.optim +import scala.collection.mutable + import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} -import scala.collection.mutable import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vectors} import org.apache.spark.mllib.linalg.CholeskyDecomposition @@ -57,7 +58,7 @@ private[ml] sealed trait NormalEquationSolver { */ private[ml] class CholeskySolver extends NormalEquationSolver { - def solve( + override def solve( bBar: Double, bbBar: Double, abBar: DenseVector, @@ -80,7 +81,7 @@ private[ml] class QuasiNewtonSolver( tol: Double, l1RegFunc: Option[(Int) => Double]) extends NormalEquationSolver { - def solve( + override def solve( bBar: Double, bbBar: Double, abBar: DenseVector, @@ -156,7 +157,7 @@ private[ml] class QuasiNewtonSolver( * Exception thrown when solving a linear system Ax = b for which the matrix A is non-invertible * (singular). 
*/ -class SingularMatrixException(message: String, cause: Throwable) +private[spark] class SingularMatrixException(message: String, cause: Throwable) extends IllegalArgumentException(message, cause) { def this(message: String) = this(message, null) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 90c24e1b590e..56ab9675700a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -47,7 +47,7 @@ private[ml] class WeightedLeastSquaresModel( * formulation: * * min,,x,z,, 1/2 sum,,i,, w,,i,, (a,,i,,^T^ x + z - b,,i,,)^2^ / sum,,i,, w,,i,, - * + lambda / delta (1/2 (1 - alpha) sumj,, (sigma,,j,, x,,j,,)^2^ + * + lambda / delta (1/2 (1 - alpha) sum,,j,, (sigma,,j,, x,,j,,)^2^ * + alpha sum,,j,, abs(sigma,,j,, x,,j,,)), * * where lambda is the regularization parameter, alpha is the ElasticNet mixing parameter, @@ -91,7 +91,7 @@ private[ml] class WeightedLeastSquares( require(elasticNetParam >= 0.0 && elasticNetParam <= 1.0, s"elasticNetParam must be in [0, 1]: $elasticNetParam") require(maxIter >= 0, s"maxIter must be a positive integer: $maxIter") - require(tol > 0, s"tol must be greater than zero: $tol") + require(tol >= 0.0, s"tol must be >= 0, but was set to $tol") /** * Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s. diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala new file mode 100644 index 000000000000..894602503220 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} +import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier} +import org.apache.spark.ml.feature.{IndexToString, RFormula} +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class GBTClassifierWrapper private ( + val pipeline: PipelineModel, + val formula: String, + val features: Array[String]) extends MLWritable { + + import GBTClassifierWrapper._ + + private val gbtcModel: GBTClassificationModel = + pipeline.stages(1).asInstanceOf[GBTClassificationModel] + + lazy val numFeatures: Int = gbtcModel.numFeatures + lazy val featureImportances: Vector = gbtcModel.featureImportances + lazy val numTrees: Int = gbtcModel.getNumTrees + lazy val treeWeights: Array[Double] = gbtcModel.treeWeights + + def summary: String = gbtcModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset) + .drop(PREDICTED_LABEL_INDEX_COL) + .drop(gbtcModel.getFeaturesCol) + } + + override def write: MLWriter = new + GBTClassifierWrapper.GBTClassifierWrapperWriter(this) +} + +private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper] { + + val PREDICTED_LABEL_INDEX_COL = "pred_label_idx" + val PREDICTED_LABEL_COL = "prediction" + + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int, + maxIter: Int, + stepSize: Double, + minInstancesPerNode: Int, + minInfoGain: Double, + checkpointInterval: Int, + lossType: String, + seed: String, + subsamplingRate: Double, + maxMemoryInMB: Int, + cacheNodeIds: Boolean): GBTClassifierWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + .setForceIndexLabel(true) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + + // get label names from output schema + val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) + .asInstanceOf[NominalAttribute] + val labels = labelAttr.values.get + + // assemble and fit the pipeline + val rfc = new GBTClassifier() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + .setMaxIter(maxIter) + .setStepSize(stepSize) + .setMinInstancesPerNode(minInstancesPerNode) + .setMinInfoGain(minInfoGain) + .setCheckpointInterval(checkpointInterval) + .setLossType(lossType) + .setSubsamplingRate(subsamplingRate) + .setMaxMemoryInMB(maxMemoryInMB) + .setCacheNodeIds(cacheNodeIds) + .setFeaturesCol(rFormula.getFeaturesCol) + .setPredictionCol(PREDICTED_LABEL_INDEX_COL) + if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong) + + val idxToStr = new IndexToString() + .setInputCol(PREDICTED_LABEL_INDEX_COL) + .setOutputCol(PREDICTED_LABEL_COL) + .setLabels(labels) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, rfc, idxToStr)) + .fit(data) + + new GBTClassifierWrapper(pipeline, formula, features) + } + + override def read: MLReader[GBTClassifierWrapper] = new GBTClassifierWrapperReader + + override def load(path: 
String): GBTClassifierWrapper = super.load(path) + + class GBTClassifierWrapperWriter(instance: GBTClassifierWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("formula" -> instance.formula) ~ + ("features" -> instance.features.toSeq) + val rMetadataJson: String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class GBTClassifierWrapperReader extends MLReader[GBTClassifierWrapper] { + + override def load(path: String): GBTClassifierWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val formula = (rMetadata \ "formula").extract[String] + val features = (rMetadata \ "features").extract[Array[String]] + + new GBTClassifierWrapper(pipeline, formula, features) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala new file mode 100644 index 000000000000..585077588eb9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class GBTRegressorWrapper private ( + val pipeline: PipelineModel, + val formula: String, + val features: Array[String]) extends MLWritable { + + private val gbtrModel: GBTRegressionModel = + pipeline.stages(1).asInstanceOf[GBTRegressionModel] + + lazy val numFeatures: Int = gbtrModel.numFeatures + lazy val featureImportances: Vector = gbtrModel.featureImportances + lazy val numTrees: Int = gbtrModel.getNumTrees + lazy val treeWeights: Array[Double] = gbtrModel.treeWeights + + def summary: String = gbtrModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(gbtrModel.getFeaturesCol) + } + + override def write: MLWriter = new + GBTRegressorWrapper.GBTRegressorWrapperWriter(this) +} + +private[r] object GBTRegressorWrapper extends MLReadable[GBTRegressorWrapper] { + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int, + maxIter: Int, + stepSize: Double, + minInstancesPerNode: Int, + minInfoGain: Double, + checkpointInterval: Int, + lossType: String, + seed: String, + subsamplingRate: Double, + maxMemoryInMB: Int, + cacheNodeIds: Boolean): GBTRegressorWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + + // assemble and fit the pipeline + val rfr = new GBTRegressor() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + .setMaxIter(maxIter) + .setStepSize(stepSize) + .setMinInstancesPerNode(minInstancesPerNode) + .setMinInfoGain(minInfoGain) + .setCheckpointInterval(checkpointInterval) + .setLossType(lossType) + .setSubsamplingRate(subsamplingRate) + .setMaxMemoryInMB(maxMemoryInMB) + .setCacheNodeIds(cacheNodeIds) + .setFeaturesCol(rFormula.getFeaturesCol) + if (seed != null && seed.length > 0) rfr.setSeed(seed.toLong) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, rfr)) + .fit(data) + + new GBTRegressorWrapper(pipeline, formula, features) + } + + override def read: MLReader[GBTRegressorWrapper] = new GBTRegressorWrapperReader + + override def load(path: String): GBTRegressorWrapper = super.load(path) + + class GBTRegressorWrapperWriter(instance: GBTRegressorWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("formula" -> instance.formula) ~ + ("features" -> instance.features.toSeq) + val rMetadataJson: String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class 
GBTRegressorWrapperReader extends MLReader[GBTRegressorWrapper] { + + override def load(path: String): GBTRegressorWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val formula = (rMetadata \ "formula").extract[String] + val features = (rMetadata \ "features").extract[Array[String]] + + new GBTRegressorWrapper(pipeline, formula, features) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index b1bb577e1ffe..995b1ef03bce 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -23,11 +23,16 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} +import org.apache.spark.ml.feature.{IndexToString, RFormula} import org.apache.spark.ml.regression._ +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ private[r] class GeneralizedLinearRegressionWrapper private ( val pipeline: PipelineModel, @@ -42,6 +47,8 @@ private[r] class GeneralizedLinearRegressionWrapper private ( val rNumIterations: Int, val isLoaded: Boolean = false) extends MLWritable { + import GeneralizedLinearRegressionWrapper._ + private val glm: GeneralizedLinearRegressionModel = pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel] @@ -52,7 +59,15 @@ private[r] class GeneralizedLinearRegressionWrapper private ( def residuals(residualsType: String): DataFrame = glm.summary.residuals(residualsType) def transform(dataset: Dataset[_]): DataFrame = { - pipeline.transform(dataset).drop(glm.getFeaturesCol) + if (rFamily == "binomial") { + pipeline.transform(dataset) + .drop(PREDICTED_LABEL_PROB_COL) + .drop(PREDICTED_LABEL_INDEX_COL) + .drop(glm.getFeaturesCol) + } else { + pipeline.transform(dataset) + .drop(glm.getFeaturesCol) + } } override def write: MLWriter = @@ -62,6 +77,10 @@ private[r] class GeneralizedLinearRegressionWrapper private ( private[r] object GeneralizedLinearRegressionWrapper extends MLReadable[GeneralizedLinearRegressionWrapper] { + val PREDICTED_LABEL_PROB_COL = "pred_label_prob" + val PREDICTED_LABEL_INDEX_COL = "pred_label_idx" + val PREDICTED_LABEL_COL = "prediction" + def fit( formula: String, data: DataFrame, @@ -71,8 +90,8 @@ private[r] object GeneralizedLinearRegressionWrapper maxIter: Int, weightCol: String, regParam: Double): GeneralizedLinearRegressionWrapper = { - val rFormula = new RFormula() - .setFormula(formula) + val rFormula = new RFormula().setFormula(formula) + if (family == "binomial") rFormula.setForceIndexLabel(true) RWrapperUtils.checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema @@ -90,9 +109,27 @@ private[r] object 
GeneralizedLinearRegressionWrapper .setWeightCol(weightCol) .setRegParam(regParam) .setFeaturesCol(rFormula.getFeaturesCol) - val pipeline = new Pipeline() - .setStages(Array(rFormulaModel, glr)) - .fit(data) + val pipeline = if (family == "binomial") { + // Convert prediction from probability to label index. + val probToPred = new ProbabilityToPrediction() + .setInputCol(PREDICTED_LABEL_PROB_COL) + .setOutputCol(PREDICTED_LABEL_INDEX_COL) + // Convert prediction from label index to original label. + val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) + .asInstanceOf[NominalAttribute] + val labels = labelAttr.values.get + val idxToStr = new IndexToString() + .setInputCol(PREDICTED_LABEL_INDEX_COL) + .setOutputCol(PREDICTED_LABEL_COL) + .setLabels(labels) + + new Pipeline() + .setStages(Array(rFormulaModel, glr.setPredictionCol(PREDICTED_LABEL_PROB_COL), + probToPred, idxToStr)) + .fit(data) + } else { + new Pipeline().setStages(Array(rFormulaModel, glr)).fit(data) + } val glm: GeneralizedLinearRegressionModel = pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel] @@ -200,3 +237,27 @@ private[r] object GeneralizedLinearRegressionWrapper } } } + +/** + * This utility transformer converts the predicted value of GeneralizedLinearRegressionModel + * with "binomial" family from probability to prediction according to threshold 0.5. + */ +private[r] class ProbabilityToPrediction private[r] (override val uid: String) + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { + + def this() = this(Identifiable.randomUID("probToPred")) + + def setInputCol(value: String): this.type = set(inputCol, value) + + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transformSchema(schema: StructType): StructType = { + StructType(schema.fields :+ StructField($(outputCol), DoubleType)) + } + + override def transform(dataset: Dataset[_]): DataFrame = { + dataset.withColumn($(outputCol), round(col($(inputCol)))) + } + + override def copy(extra: ParamMap): ProbabilityToPrediction = defaultCopy(extra) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index 1df3662a5822..b59fe292349b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -56,6 +56,14 @@ private[r] object RWrappers extends MLReader[Object] { ALSWrapper.load(path) case "org.apache.spark.ml.r.LogisticRegressionWrapper" => LogisticRegressionWrapper.load(path) + case "org.apache.spark.ml.r.RandomForestRegressorWrapper" => + RandomForestRegressorWrapper.load(path) + case "org.apache.spark.ml.r.RandomForestClassifierWrapper" => + RandomForestClassifierWrapper.load(path) + case "org.apache.spark.ml.r.GBTRegressorWrapper" => + GBTRegressorWrapper.load(path) + case "org.apache.spark.ml.r.GBTClassifierWrapper" => + GBTClassifierWrapper.load(path) case _ => throw new SparkException(s"SparkR read.ml does not support load $className") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala new file mode 100644 index 000000000000..6947ba7e7597 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class RandomForestClassifierWrapper private ( + val pipeline: PipelineModel, + val formula: String, + val features: Array[String]) extends MLWritable { + + private val rfcModel: RandomForestClassificationModel = + pipeline.stages(1).asInstanceOf[RandomForestClassificationModel] + + lazy val numFeatures: Int = rfcModel.numFeatures + lazy val featureImportances: Vector = rfcModel.featureImportances + lazy val numTrees: Int = rfcModel.getNumTrees + lazy val treeWeights: Array[Double] = rfcModel.treeWeights + + def summary: String = rfcModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(rfcModel.getFeaturesCol) + } + + override def write: MLWriter = new + RandomForestClassifierWrapper.RandomForestClassifierWrapperWriter(this) +} + +private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestClassifierWrapper] { + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int, + numTrees: Int, + impurity: String, + minInstancesPerNode: Int, + minInfoGain: Double, + checkpointInterval: Int, + featureSubsetStrategy: String, + seed: String, + subsamplingRate: Double, + probabilityCol: String, + maxMemoryInMB: Int, + cacheNodeIds: Boolean): RandomForestClassifierWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + + // assemble and fit the pipeline + val rfc = new RandomForestClassifier() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + .setNumTrees(numTrees) + .setImpurity(impurity) + .setMinInstancesPerNode(minInstancesPerNode) + .setMinInfoGain(minInfoGain) + .setCheckpointInterval(checkpointInterval) + .setFeatureSubsetStrategy(featureSubsetStrategy) + .setSubsamplingRate(subsamplingRate) + .setMaxMemoryInMB(maxMemoryInMB) + .setCacheNodeIds(cacheNodeIds) + .setProbabilityCol(probabilityCol) + .setFeaturesCol(rFormula.getFeaturesCol) + if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong) + 
+ val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, rfc)) + .fit(data) + + new RandomForestClassifierWrapper(pipeline, formula, features) + } + + override def read: MLReader[RandomForestClassifierWrapper] = + new RandomForestClassifierWrapperReader + + override def load(path: String): RandomForestClassifierWrapper = super.load(path) + + class RandomForestClassifierWrapperWriter(instance: RandomForestClassifierWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("formula" -> instance.formula) ~ + ("features" -> instance.features.toSeq) + val rMetadataJson: String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class RandomForestClassifierWrapperReader extends MLReader[RandomForestClassifierWrapper] { + + override def load(path: String): RandomForestClassifierWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val formula = (rMetadata \ "formula").extract[String] + val features = (rMetadata \ "features").extract[Array[String]] + + new RandomForestClassifierWrapper(pipeline, formula, features) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala new file mode 100644 index 000000000000..4b9a3a731da9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class RandomForestRegressorWrapper private ( + val pipeline: PipelineModel, + val formula: String, + val features: Array[String]) extends MLWritable { + + private val rfrModel: RandomForestRegressionModel = + pipeline.stages(1).asInstanceOf[RandomForestRegressionModel] + + lazy val numFeatures: Int = rfrModel.numFeatures + lazy val featureImportances: Vector = rfrModel.featureImportances + lazy val numTrees: Int = rfrModel.getNumTrees + lazy val treeWeights: Array[Double] = rfrModel.treeWeights + + def summary: String = rfrModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(rfrModel.getFeaturesCol) + } + + override def write: MLWriter = new + RandomForestRegressorWrapper.RandomForestRegressorWrapperWriter(this) +} + +private[r] object RandomForestRegressorWrapper extends MLReadable[RandomForestRegressorWrapper] { + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int, + numTrees: Int, + impurity: String, + minInstancesPerNode: Int, + minInfoGain: Double, + checkpointInterval: Int, + featureSubsetStrategy: String, + seed: String, + subsamplingRate: Double, + maxMemoryInMB: Int, + cacheNodeIds: Boolean): RandomForestRegressorWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + + // assemble and fit the pipeline + val rfr = new RandomForestRegressor() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + .setNumTrees(numTrees) + .setImpurity(impurity) + .setMinInstancesPerNode(minInstancesPerNode) + .setMinInfoGain(minInfoGain) + .setCheckpointInterval(checkpointInterval) + .setFeatureSubsetStrategy(featureSubsetStrategy) + .setSubsamplingRate(subsamplingRate) + .setMaxMemoryInMB(maxMemoryInMB) + .setCacheNodeIds(cacheNodeIds) + .setFeaturesCol(rFormula.getFeaturesCol) + if (seed != null && seed.length > 0) rfr.setSeed(seed.toLong) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, rfr)) + .fit(data) + + new RandomForestRegressorWrapper(pipeline, formula, features) + } + + override def read: MLReader[RandomForestRegressorWrapper] = new RandomForestRegressorWrapperReader + + override def load(path: String): RandomForestRegressorWrapper = super.load(path) + + class RandomForestRegressorWrapperWriter(instance: RandomForestRegressorWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("formula" -> instance.formula) ~ + ("features" -> instance.features.toSeq) + val 
rMetadataJson: String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class RandomForestRegressorWrapperReader extends MLReader[RandomForestRegressorWrapper] { + + override def load(path: String): RandomForestRegressorWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val formula = (rMetadata \ "formula").extract[String] + val features = (rMetadata \ "features").extract[Array[String]] + + new RandomForestRegressorWrapper(pipeline, formula, features) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 33cb25c8c7f6..1938e8ecc513 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -255,7 +255,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = - dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } @@ -776,8 +776,10 @@ class GeneralizedLinearRegressionModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = { - copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra) - .setParent(parent) + val copied = copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), + extra) + if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) + copied.setParent(parent) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 519f3bdec82d..9639b07496c1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -31,7 +31,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.linalg.BLAS._ -import org.apache.spark.ml.optim.{NormalEquationSolver, WeightedLeastSquares} +import org.apache.spark.ml.optim.WeightedLeastSquares import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ @@ -160,11 +160,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String /** * Set the solver algorithm used for optimization. * In case of linear regression, this can be "l-bfgs", "normal" and "auto". - * "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton - * optimization method. "normal" denotes using Normal Equation as an analytical - * solution to the linear regression problem. 
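Several model classes touched above (BisectingKMeansModel, GaussianMixtureModel, KMeansModel, GeneralizedLinearRegressionModel) now carry their training summary and parent through `copy()` instead of silently dropping them. A hedged check of that behavior, assuming a `trainingDF` with a "features" column:

```scala
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.param.ParamMap

// Hedged check of the copy() fixes above: the training summary and the
// parent estimator survive a copy(). trainingDF is assumed to exist.
val kmeans = new KMeans().setK(2).setSeed(1L)
val model = kmeans.fit(trainingDF)

val copied = model.copy(ParamMap.empty)
assert(copied.hasSummary)                 // previously false after copy()
assert(copied.parent.uid == kmeans.uid)   // parent is propagated as well
```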
- * The default value is "auto" which means that the solver algorithm is - * selected automatically. + * - "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton + * optimization method. + * - "normal" denotes using Normal Equation as an analytical solution to the linear regression + * problem. This solver is limited to [[LinearRegression.MAX_FEATURES_FOR_NORMAL_SOLVER]]. + * - "auto" (default) means that the solver algorithm is selected automatically. + * The Normal Equations solver will be used when possible, but this will automatically fall + * back to iterative optimization methods when needed. * * @group setParam */ @@ -190,7 +192,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = dataset.select( - col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } @@ -404,6 +406,14 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] { @Since("1.6.0") override def load(path: String): LinearRegression = super.load(path) + + /** + * When using [[LinearRegression.solver]] == "normal", the solver must limit the number of + * features to at most this number. The entire covariance matrix X^T^X will be collected + * to the driver. This limit helps prevent memory overflow errors. + */ + @Since("2.1.0") + val MAX_FEATURES_FOR_NORMAL_SOLVER: Int = WeightedLeastSquares.MAX_NUM_FEATURES } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 5e9e6ff1a569..cb3ca1b6c4be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -41,17 +41,11 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration private[libsvm] class LibSVMOutputWriter( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { - override val path: String = { - val compressionExtension = TextOutputWriter.getCompressionExtension(context) - new Path(stagingDir, fileNamePrefix + ".libsvm" + compressionExtension).toString - } - private[this] val buffer = new Text() private val recordWriter: RecordWriter[NullWritable, Text] = { @@ -135,11 +129,14 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour dataSchema: StructType): OutputWriterFactory = { new OutputWriterFactory { override def newInstance( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new LibSVMOutputWriter(stagingDir, fileNamePrefix, dataSchema, context) + new LibSVMOutputWriter(path, dataSchema, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + ".libsvm" + TextOutputWriter.getCompressionExtension(context) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 0fdba1cb8814..5d1a39f7c16d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -221,7 +221,7 @@ class TrainValidationSplitModel private[ml] ( uid, bestModel.copy(extra).asInstanceOf[Model[_]], validationMetrics.clone()) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 904000f50d0a..034e3625e8c0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -638,13 +638,13 @@ private[python] class PythonMLLibAPI extends Serializable { selectorType: String, numTopFeatures: Int, percentile: Double, - alpha: Double, + fpr: Double, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = { new ChiSqSelector() .setSelectorType(selectorType) .setNumTopFeatures(numTopFeatures) .setPercentile(percentile) - .setAlpha(alpha) + .setFpr(fpr) .fit(data.rdd) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index f8276de4f23d..f9156b642785 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -161,7 +161,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { Loader.checkSchema[Data](dataFrame.schema) val features = dataArray.rdd.map { - case Row(feature: Int) => (feature) + case Row(feature: Int) => feature }.collect() new ChiSqSelectorModel(features) @@ -171,18 +171,20 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { /** * Creates a ChiSquared feature selector. - * The selector supports three selection methods: `kbest`, `percentile` and `fpr`. - * `kbest` chooses the `k` top features according to a chi-squared test. - * `percentile` is similar but chooses a fraction of all features instead of a fixed number. - * `fpr` chooses all features whose false positive rate meets some threshold. - * By default, the selection method is `kbest`, the default number of top features is 50. + * The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`. + * - `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. + * - `percentile` is similar but chooses a fraction of all features instead of a fixed number. + * - `fpr` chooses all features whose p-value is below a threshold, thus controlling the false + * positive rate of selection. + * By default, the selection method is `numTopFeatures`, with the default number of top features + * set to 50. 
*/ @Since("1.3.0") class ChiSqSelector @Since("2.1.0") () extends Serializable { var numTopFeatures: Int = 50 var percentile: Double = 0.1 - var alpha: Double = 0.05 - var selectorType = ChiSqSelector.KBest + var fpr: Double = 0.05 + var selectorType = ChiSqSelector.NumTopFeatures /** * The is the same to call this() and setNumTopFeatures(numTopFeatures) @@ -207,15 +209,15 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { } @Since("2.1.0") - def setAlpha(value: Double): this.type = { - require(0.0 <= value && value <= 1.0, "Alpha must be in [0,1]") - alpha = value + def setFpr(value: Double): this.type = { + require(0.0 <= value && value <= 1.0, "FPR must be in [0,1]") + fpr = value this } @Since("2.1.0") def setSelectorType(value: String): this.type = { - require(ChiSqSelector.supportedSelectorTypes.toSeq.contains(value), + require(ChiSqSelector.supportedSelectorTypes.contains(value), s"ChiSqSelector Type: $value was not supported.") selectorType = value this @@ -232,7 +234,7 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex val features = selectorType match { - case ChiSqSelector.KBest => + case ChiSqSelector.NumTopFeatures => chiSqTestResult .sortBy { case (res, _) => res.pValue } .take(numTopFeatures) @@ -242,7 +244,7 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { .take((chiSqTestResult.length * percentile).toInt) case ChiSqSelector.FPR => chiSqTestResult - .filter { case (res, _) => res.pValue < alpha } + .filter { case (res, _) => res.pValue < fpr } case errorType => throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") } @@ -251,22 +253,17 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { } } -@Since("2.1.0") -object ChiSqSelector { +private[spark] object ChiSqSelector { - /** String name for `kbest` selector type. */ - private[spark] val KBest: String = "kbest" + /** String name for `numTopFeatures` selector type. */ + val NumTopFeatures: String = "numTopFeatures" /** String name for `percentile` selector type. */ - private[spark] val Percentile: String = "percentile" + val Percentile: String = "percentile" /** String name for `fpr` selector type. */ private[spark] val FPR: String = "fpr" - /** Set of selector type and param pairs that ChiSqSelector supports. */ - private[spark] val supportedTypeAndParamPairs = Set(KBest -> "numTopFeatures", - Percentile -> "percentile", FPR -> "alpha") - /** Set of selector types that ChiSqSelector supports. 
*/ - private[spark] val supportedSelectorTypes = supportedTypeAndParamPairs.map(_._1) + val supportedSelectorTypes: Array[String] = Array(NumTopFeatures, Percentile, FPR) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index 426bb818c926..f5ca1c221d66 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.pmml.export import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.beans.BeanProperty @@ -34,7 +34,7 @@ private[mllib] trait PMMLModelExport { val version = getClass.getPackage.getImplementationVersion val app = new Application("Apache Spark MLlib").setVersion(version) val timestamp = new Timestamp() - .addContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) + .addContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss", Locale.US).format(new Date())) val header = new Header() .setApplication(app) .setTimestamp(timestamp) diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java index e4f678fef1d1..a0ccc5632fa0 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -29,13 +29,15 @@ public class JavaDefaultReadWriteSuite extends SharedSparkSession { File tempDir = null; + @Override public void setUp() throws IOException { super.setUp(); tempDir = Utils.createTempDir( - System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite"); + System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite"); } + @Override public void tearDown() { super.tearDown(); diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java index 8c6bced52dd7..e57a7ce49857 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java @@ -43,7 +43,7 @@ public class JavaStreamingLogisticRegressionSuite { @Before public void setUp() { SparkConf conf = new SparkConf() - .setMaster("local[2]") + .setMaster("local[4]") .setAppName("test") .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java index d41fc0e4dca9..52b5e7a958ea 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java @@ -42,7 +42,7 @@ public class JavaStreamingKMeansSuite { @Before public void setUp() { SparkConf conf = new SparkConf() - .setMaster("local[2]") + .setMaster("local[4]") .setAppName("test") .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); diff --git 
a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java index ab554475d59a..bfcdbd2edbd4 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java @@ -42,7 +42,7 @@ public class JavaStreamingLinearRegressionSuite { @Before public void setUp() { SparkConf conf = new SparkConf() - .setMaster("local[2]") + .setMaster("local[4]") .setAppName("test") .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java index 1abaa39eadc2..f5ae8df30ad0 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -49,6 +49,8 @@ public class JavaStatisticsSuite { @Before public void setUp() { SparkConf conf = new SparkConf() + .setMaster("local[4]") + .setAppName("JavaStatistics") .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); spark = SparkSession.builder() .master("local[2]") diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 6413ca1f8b19..dafc6c200f95 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -101,13 +101,31 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } } + test("Pipeline.copy") { + val hashingTF = new HashingTF() + .setNumFeatures(100) + val pipeline = new Pipeline("pipeline").setStages(Array[Transformer](hashingTF)) + val copied = pipeline.copy(ParamMap(hashingTF.numFeatures -> 10)) + + assert(copied.uid === pipeline.uid, + "copy should create an instance with the same UID") + assert(copied.getStages(0).asInstanceOf[HashingTF].getNumFeatures === 10, + "copy should handle extra stage params") + } + test("PipelineModel.copy") { val hashingTF = new HashingTF() .setNumFeatures(100) - val model = new PipelineModel("pipeline", Array[Transformer](hashingTF)) + val model = new PipelineModel("pipelineModel", Array[Transformer](hashingTF)) + .setParent(new Pipeline()) val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10)) - require(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10, + + assert(copied.uid === model.uid, + "copy should create an instance with the same UID") + assert(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10, "copy should handle extra stage params") + assert(copied.parent === model.parent, + "copy should create an instance with the same parent") } test("pipeline model constructors") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala new file mode 100644 index 000000000000..03e0c536a973 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext { + + import PredictorSuite._ + + test("should support all NumericType labels and not support other types") { + val df = spark.createDataFrame(Seq( + (0, Vectors.dense(0, 2, 3)), + (1, Vectors.dense(0, 3, 9)), + (0, Vectors.dense(0, 2, 6)) + )).toDF("label", "features") + + val types = + Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) + + val predictor = new MockPredictor() + + types.foreach { t => + predictor.fit(df.select(col("label").cast(t), col("features"))) + } + + intercept[IllegalArgumentException] { + predictor.fit(df.select(col("label").cast(StringType), col("features"))) + } + } +} + +object PredictorSuite { + + class MockPredictor(override val uid: String) + extends Predictor[Vector, MockPredictor, MockPredictionModel] { + + def this() = this(Identifiable.randomUID("mockpredictor")) + + override def train(dataset: Dataset[_]): MockPredictionModel = { + require(dataset.schema("label").dataType == DoubleType) + new MockPredictionModel(uid) + } + + override def copy(extra: ParamMap): MockPredictor = + throw new NotImplementedError() + } + + class MockPredictionModel(override val uid: String) + extends PredictionModel[Vector, MockPredictionModel] { + + def this() = this(Identifiable.randomUID("mockpredictormodel")) + + override def predict(features: Vector): Double = + throw new NotImplementedError() + + override def copy(extra: ParamMap): MockPredictionModel = + throw new NotImplementedError() + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index bc631dc6d314..2877285eb4d5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, SparseVector, Vector, Vectors} -import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -141,6 +141,12 @@ class LogisticRegressionSuite 
assert(model.getProbabilityCol === "probability") assert(model.intercept !== 0.0) assert(model.hasParent) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) } test("empty probabilityCol") { @@ -251,9 +257,6 @@ class LogisticRegressionSuite mlr.setFitIntercept(false) val mlrModel = mlr.fit(smallMultinomialDataset) assert(mlrModel.interceptVector === Vectors.sparse(3, Seq())) - - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) } test("logistic regression with setters") { @@ -1807,7 +1810,6 @@ class LogisticRegressionSuite .objectiveHistory .sliding(2) .forall(x => x(0) >= x(1))) - } test("binary logistic regression with weighted data") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index f2368a9f8dad..49797d938d75 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -41,6 +42,13 @@ class BisectingKMeansSuite assert(bkm.getPredictionCol === "prediction") assert(bkm.getMaxIter === 20) assert(bkm.getMinDivisibleClusterSize === 1.0) + val model = bkm.setMaxIter(1).fit(dataset) + + // copied model must have the same parent + MLTestingUtils.checkCopy(model) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) } test("setter/getter") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 003fa6abf659..7165b63ed3b9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -43,6 +44,13 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(gm.getPredictionCol === "prediction") assert(gm.getMaxIter === 100) assert(gm.getTol === 0.01) + val model = gm.setMaxIter(1).fit(dataset) + + // copied model must have the same parent + MLTestingUtils.checkCopy(model) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) } test("set parameters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index ca392653557c..73972557d263 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -19,7 +19,8 @@ package 
org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} @@ -47,6 +48,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) assert(kmeans.getInitSteps === 2) assert(kmeans.getTol === 1e-4) + val model = kmeans.setMaxIter(1).fit(dataset) + + // copied model must have the same parent + MLTestingUtils.checkCopy(model) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) } test("set parameters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index 6af06d82d671..80970fd74488 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -19,85 +19,72 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.feature import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Dataset, Row} class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - test("Test Chi-Square selector") { - import testImplicits._ - val data = Seq( - LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))), - LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))), - LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))), - LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0))) - ) + @transient var dataset: Dataset[_] = _ - val preFilteredData = Seq( - Vectors.dense(8.0), - Vectors.dense(0.0), - Vectors.dense(0.0), - Vectors.dense(8.0) - ) + override def beforeAll(): Unit = { + super.beforeAll() - val df = sc.parallelize(data.zip(preFilteredData)) - .map(x => (x._1.label, x._1.features, x._2)) - .toDF("label", "data", "preFilteredData") - - val selector = new ChiSqSelector() - .setSelectorType("kbest") - .setNumTopFeatures(1) - .setFeaturesCol("data") - .setLabelCol("label") - .setOutputCol("filtered") - - selector.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach { - case Row(vec1: Vector, vec2: Vector) => - assert(vec1 ~== vec2 absTol 1e-1) - } - - selector.setSelectorType("percentile").setPercentile(0.34).fit(df).transform(df) - .select("filtered", "preFilteredData").collect().foreach { - case Row(vec1: Vector, vec2: Vector) => - assert(vec1 ~== vec2 absTol 1e-1) - } + // Toy dataset, including the top feature for a chi-squared test. + // These data are chosen such that each feature's test has a distinct p-value. 
+ /* To verify the results with R, run: + library(stats) + x1 <- c(8.0, 0.0, 0.0, 7.0, 8.0) + x2 <- c(7.0, 9.0, 9.0, 9.0, 7.0) + x3 <- c(0.0, 6.0, 8.0, 5.0, 3.0) + y <- c(0.0, 1.0, 1.0, 2.0, 2.0) + chisq.test(x1,y) + chisq.test(x2,y) + chisq.test(x3,y) + */ + dataset = spark.createDataFrame(Seq( + (0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0))), Vectors.dense(8.0)), + (1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0))), Vectors.dense(0.0)), + (1.0, Vectors.dense(Array(0.0, 9.0, 8.0)), Vectors.dense(0.0)), + (2.0, Vectors.dense(Array(7.0, 9.0, 5.0)), Vectors.dense(7.0)), + (2.0, Vectors.dense(Array(8.0, 7.0, 3.0)), Vectors.dense(8.0)) + )).toDF("label", "features", "topFeature") + } - val preFilteredData2 = Seq( - Vectors.dense(8.0, 7.0), - Vectors.dense(0.0, 9.0), - Vectors.dense(0.0, 9.0), - Vectors.dense(8.0, 9.0) - ) + test("params") { + ParamsSuite.checkParams(new ChiSqSelector) + val model = new ChiSqSelectorModel("myModel", + new org.apache.spark.mllib.feature.ChiSqSelectorModel(Array(1, 3, 4))) + ParamsSuite.checkParams(model) + } - val df2 = sc.parallelize(data.zip(preFilteredData2)) - .map(x => (x._1.label, x._1.features, x._2)) - .toDF("label", "data", "preFilteredData") + test("Test Chi-Square selector: numTopFeatures") { + val selector = new ChiSqSelector() + .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1) + ChiSqSelectorSuite.testSelector(selector, dataset) + } - selector.setSelectorType("fpr").setAlpha(0.2).fit(df2).transform(df2) - .select("filtered", "preFilteredData").collect().foreach { - case Row(vec1: Vector, vec2: Vector) => - assert(vec1 ~== vec2 absTol 1e-1) - } + test("Test Chi-Square selector: percentile") { + val selector = new ChiSqSelector() + .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.34) + ChiSqSelectorSuite.testSelector(selector, dataset) } - test("ChiSqSelector read/write") { - val t = new ChiSqSelector() - .setFeaturesCol("myFeaturesCol") - .setLabelCol("myLabelCol") - .setOutputCol("myOutputCol") - .setNumTopFeatures(2) - testDefaultReadWrite(t) + test("Test Chi-Square selector: fpr") { + val selector = new ChiSqSelector() + .setOutputCol("filtered").setSelectorType("fpr").setFpr(0.2) + ChiSqSelectorSuite.testSelector(selector, dataset) } - test("ChiSqSelectorModel read/write") { - val oldModel = new feature.ChiSqSelectorModel(Array(1, 3)) - val instance = new ChiSqSelectorModel("myChiSqSelectorModel", oldModel) - val newInstance = testDefaultReadWrite(instance) - assert(newInstance.selectedFeatures === instance.selectedFeatures) + test("read/write") { + def checkModelData(model: ChiSqSelectorModel, model2: ChiSqSelectorModel): Unit = { + assert(model.selectedFeatures === model2.selectedFeatures) + } + val nb = new ChiSqSelector + testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and not support other types") { @@ -108,3 +95,25 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext } } } + +object ChiSqSelectorSuite { + + private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): Unit = { + selector.fit(dataset).transform(dataset).select("filtered", "topFeature").collect() + .foreach { case Row(vec1: Vector, vec2: Vector) => + assert(vec1 ~== vec2 absTol 1e-1) + } + } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. 
+ * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "selectorType" -> "percentile", + "numTopFeatures" -> 1, + "percentile" -> 0.12, + "outputCol" -> "myOutput" + ) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index ac1ef5feb95b..111bc974642d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors} -import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.random._ @@ -183,6 +183,9 @@ class GeneralizedLinearRegressionSuite // copied model must have the same parent. MLTestingUtils.checkCopy(model) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) assert(model.getFeaturesCol === "features") assert(model.getPredictionCol === "prediction") diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index c0e8afbf5e34..df97d0b2ae7a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} -import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} @@ -143,6 +143,9 @@ class LinearRegressionSuite // copied model must have the same parent. 
MLTestingUtils.checkCopy(model) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) model.transform(datasetWithDenseFeature) .select("label", "prediction") diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 87100ae2e342..4463a9b6e543 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -22,11 +22,11 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.linalg.{DenseMatrix, Vectors} +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType @@ -78,6 +78,10 @@ class TrainValidationSplitSuite .setTrainRatio(0.5) .setSeed(42L) val cvModel = cv.fit(dataset) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(cvModel) + val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index ac702b4b7c69..77219e500617 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -54,33 +54,34 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))), LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2) val preFilteredData = - Set(LabeledPoint(0.0, Vectors.dense(Array(8.0))), + Seq(LabeledPoint(0.0, Vectors.dense(Array(8.0))), LabeledPoint(1.0, Vectors.dense(Array(0.0))), LabeledPoint(1.0, Vectors.dense(Array(0.0))), LabeledPoint(2.0, Vectors.dense(Array(8.0)))) val model = new ChiSqSelector(1).fit(labeledDiscreteData) val filteredData = labeledDiscreteData.map { lp => LabeledPoint(lp.label, model.transform(lp.features)) - }.collect().toSet - assert(filteredData == preFilteredData) + }.collect().toSeq + assert(filteredData === preFilteredData) } - test("ChiSqSelector by FPR transform test (sparse & dense vector)") { + test("ChiSqSelector by fpr transform test (sparse & dense vector)") { val labeledDiscreteData = sc.parallelize( Seq(LabeledPoint(0.0, Vectors.sparse(4, Array((0, 8.0), (1, 7.0)))), LabeledPoint(1.0, Vectors.sparse(4, Array((1, 9.0), (2, 6.0), (3, 4.0)))), LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0, 4.0))), LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0, 9.0)))), 2) val preFilteredData = - Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))), + Seq(LabeledPoint(0.0, Vectors.dense(Array(0.0))),
LabeledPoint(1.0, Vectors.dense(Array(4.0))), LabeledPoint(1.0, Vectors.dense(Array(4.0))), LabeledPoint(2.0, Vectors.dense(Array(9.0)))) - val model = new ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(labeledDiscreteData) + val model: ChiSqSelectorModel = new ChiSqSelector().setSelectorType("fpr") + .setFpr(0.1).fit(labeledDiscreteData) val filteredData = labeledDiscreteData.map { lp => LabeledPoint(lp.label, model.transform(lp.features)) - }.collect().toSet - assert(filteredData == preFilteredData) + }.collect().toSeq + assert(filteredData === preFilteredData) } test("model load / save") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index e4e9be39ff6f..665708a780c4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -155,13 +155,17 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { val tempDir = Utils.createTempDir() val outputDir = new File(tempDir, "output") MLUtils.saveAsLibSVMFile(examples, outputDir.toURI.toString) - val lines = outputDir.listFiles() + val sources = outputDir.listFiles() .filter(_.getName.startsWith("part-")) - .flatMap(Source.fromFile(_).getLines()) - .toSet - val expected = Set("1.1 1:1.23 3:4.56", "0.0 1:1.01 2:2.02 3:3.03") - assert(lines === expected) - Utils.deleteRecursively(tempDir) + .map(Source.fromFile) + Utils.tryWithSafeFinally { + val lines = sources.flatMap(_.getLines()).toSet + val expected = Set("1.1 1:1.23 3:4.56", "0.0 1:1.01 2:2.02 3:3.03") + assert(lines === expected) + } { + sources.foreach(_.close()) + Utils.deleteRecursively(tempDir) + } } test("appendBias") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index 6bb7ed9c9513..8157792a3460 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -34,7 +34,7 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite => override def beforeAll() { super.beforeAll() spark = SparkSession.builder - .master("local[2]") + .master("local[4]") .appName("MLlibUnitTest") .getOrCreate() sc = spark.sparkContext diff --git a/pom.xml b/pom.xml index aaf7cfa7eb2a..04d2eaa1d3ba 100644 --- a/pom.xml +++ b/pom.xml @@ -2693,6 +2693,54 @@ + <profile> + <id>snapshots-and-staging</id> + <properties> + <asf.staging>https://repository.apache.org/content/groups/staging/</asf.staging> + <asf.snapshots>https://repository.apache.org/content/repositories/snapshots/</asf.snapshots> + </properties> + <pluginRepositories> + <pluginRepository> + <id>ASF Staging</id> + <url>${asf.staging}</url> + </pluginRepository> + <pluginRepository> + <id>ASF Snapshots</id> + <url>${asf.snapshots}</url> + <snapshots> + <enabled>true</enabled> + </snapshots> + <releases> + <enabled>false</enabled> + </releases> + </pluginRepository> + </pluginRepositories> + <repositories> + <repository> + <id>ASF Staging</id> + <url>${asf.staging}</url> + </repository> + <repository> + <id>ASF Snapshots</id> + <url>${asf.snapshots}</url> + <snapshots> + <enabled>true</enabled> + </snapshots> + <releases> + <enabled>false</enabled> + </releases> + </repository> + </repositories> + </profile>
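
For reference, the solver options documented in the LinearRegression change above are driven through the existing solver param (the doc comment sits over its setter). A minimal sketch, assuming an active SparkSession and a DataFrame named training with "label" and "features" columns (names chosen only for illustration):

import org.apache.spark.ml.regression.LinearRegression

// "normal" requests the Normal Equation solver; it only applies when the number of
// features is at most LinearRegression.MAX_FEATURES_FOR_NORMAL_SOLVER, because the
// X^T X matrix is collected to the driver.
val lrNormal = new LinearRegression()
  .setSolver("normal")
  .setLabelCol("label")
  .setFeaturesCol("features")
val normalModel = lrNormal.fit(training)

// The default "auto" uses the Normal Equations solver when possible and otherwise
// falls back to the iterative "l-bfgs" solver.
val autoModel = new LinearRegression().setSolver("auto").fit(training)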
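
Similarly, the renamed ChiSqSelector selection methods can be exercised through the setters used in the updated ml ChiSqSelectorSuite above. A minimal sketch, assuming a DataFrame df with "label" and "features" columns (illustrative names):

import org.apache.spark.ml.feature.ChiSqSelector

// Former "kbest": keep a fixed number of features ranked by chi-squared p-value.
val topK = new ChiSqSelector()
  .setSelectorType("numTopFeatures")
  .setNumTopFeatures(1)
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setOutputCol("selected")

// Former "fpr" with setAlpha: keep every feature whose p-value is below the threshold,
// now configured via setFpr.
val byFpr = new ChiSqSelector()
  .setSelectorType("fpr")
  .setFpr(0.05)
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setOutputCol("selected")

val selected = topK.fit(df).transform(df)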