diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7a89c01fee73..daee09de8826 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -44,7 +44,9 @@ exportMethods("glm", "spark.gaussianMixture", "spark.als", "spark.kstest", - "spark.logit") + "spark.logit", + "spark.randomForest", + "spark.gbt") # Job group lifecycle management methods export("setJobGroup", @@ -350,7 +352,11 @@ export("as.DataFrame", "uncacheTable", "print.summary.GeneralizedLinearRegressionModel", "read.ml", - "print.summary.KSTest") + "print.summary.KSTest", + "print.summary.RandomForestRegressionModel", + "print.summary.RandomForestClassificationModel", + "print.summary.GBTRegressionModel", + "print.summary.GBTClassificationModel") export("structField", "structField.jobj", @@ -375,6 +381,10 @@ S3method(print, structField) S3method(print, structType) S3method(print, summary.GeneralizedLinearRegressionModel) S3method(print, summary.KSTest) +S3method(print, summary.RandomForestRegressionModel) +S3method(print, summary.RandomForestClassificationModel) +S3method(print, summary.GBTRegressionModel) +S3method(print, summary.GBTClassificationModel) S3method(structField, character) S3method(structField, jobj) S3method(structType, jobj) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 1df8bbf9fe60..1cf9b38ea648 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -788,7 +788,7 @@ setMethod("write.json", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) - invisible(callJMethod(write, "json", path)) + invisible(handledCallJMethod(write, "json", path)) }) #' Save the contents of SparkDataFrame as an ORC file, preserving the schema. @@ -819,7 +819,7 @@ setMethod("write.orc", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) - invisible(callJMethod(write, "orc", path)) + invisible(handledCallJMethod(write, "orc", path)) }) #' Save the contents of SparkDataFrame as a Parquet file, preserving the schema. @@ -851,7 +851,7 @@ setMethod("write.parquet", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) - invisible(callJMethod(write, "parquet", path)) + invisible(handledCallJMethod(write, "parquet", path)) }) #' @rdname write.parquet @@ -895,7 +895,7 @@ setMethod("write.text", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) - invisible(callJMethod(write, "text", path)) + invisible(handledCallJMethod(write, "text", path)) }) #' Distinct @@ -3342,7 +3342,7 @@ setMethod("write.jdbc", jprops <- varargsToJProperties(...) write <- callJMethod(x@sdf, "write") write <- callJMethod(write, "mode", jmode) - invisible(callJMethod(write, "jdbc", url, tableName, jprops)) + invisible(handledCallJMethod(write, "jdbc", url, tableName, jprops)) }) #' randomSplit diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 216ca51666ba..38d83c6e5c52 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -350,7 +350,7 @@ read.json.default <- function(path, ...) { paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "json", paths) + sdf <- handledCallJMethod(read, "json", paths) dataFrame(sdf) } @@ -422,7 +422,7 @@ read.orc <- function(path, ...) 
{ path <- suppressWarnings(normalizePath(path)) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "orc", path) + sdf <- handledCallJMethod(read, "orc", path) dataFrame(sdf) } @@ -444,7 +444,7 @@ read.parquet.default <- function(path, ...) { paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "parquet", paths) + sdf <- handledCallJMethod(read, "parquet", paths) dataFrame(sdf) } @@ -496,7 +496,7 @@ read.text.default <- function(path, ...) { paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "text", paths) + sdf <- handledCallJMethod(read, "text", paths) dataFrame(sdf) } @@ -914,12 +914,13 @@ read.jdbc <- function(url, tableName, } else { numPartitions <- numToInt(numPartitions) } - sdf <- callJMethod(read, "jdbc", url, tableName, as.character(partitionColumn), - numToInt(lowerBound), numToInt(upperBound), numPartitions, jprops) + sdf <- handledCallJMethod(read, "jdbc", url, tableName, as.character(partitionColumn), + numToInt(lowerBound), numToInt(upperBound), numPartitions, jprops) } else if (length(predicates) > 0) { - sdf <- callJMethod(read, "jdbc", url, tableName, as.list(as.character(predicates)), jprops) + sdf <- handledCallJMethod(read, "jdbc", url, tableName, as.list(as.character(predicates)), + jprops) } else { - sdf <- callJMethod(read, "jdbc", url, tableName, jprops) + sdf <- handledCallJMethod(read, "jdbc", url, tableName, jprops) } dataFrame(sdf) } diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R index 03e70bb2cb82..0a789e6c379d 100644 --- a/R/pkg/R/backend.R +++ b/R/pkg/R/backend.R @@ -108,13 +108,27 @@ invokeJava <- function(isStatic, objId, methodName, ...) { conn <- get(".sparkRCon", .sparkREnv) writeBin(requestMessage, conn) - # TODO: check the status code to output error information returnStatus <- readInt(conn) + handleErrors(returnStatus, conn) + + # Backend will send +1 as keep alive value to prevent various connection timeouts + # on very long running jobs. See spark.r.heartBeatInterval + while (returnStatus == 1) { + returnStatus <- readInt(conn) + handleErrors(returnStatus, conn) + } + + readObject(conn) +} + +# Helper function to check for returned errors and print appropriate error message to user +handleErrors <- function(returnStatus, conn) { if (length(returnStatus) == 0) { stop("No status is returned. Java SparkR backend might have failed.") } - if (returnStatus != 0) { + + # 0 is success and +1 is reserved for heartbeats. Other negative values indicate errors. 
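The handledCallJMethod wrappers introduced above turn backend exceptions into short, classified R errors, and handleErrors now tolerates the +1 heartbeat values sent during long-running jobs. A minimal, hedged sketch of what this looks like from a user session; the path, timeout, and interval values are illustrative only, not recommended defaults.

```r
library(SparkR)

# Both knobs referenced in these changes are configurable (values here are illustrative):
# SPARKR_BACKEND_CONNECTION_TIMEOUT feeds the socketConnection() timeouts,
# spark.r.heartBeatInterval controls how often the backend sends the +1 keep-alive.
Sys.setenv(SPARKR_BACKEND_CONNECTION_TIMEOUT = "6000")
sparkR.session(sparkConfig = list(spark.r.heartBeatInterval = "60"))

# Reader/writer failures now surface as concise messages such as
# "Error in json : analysis error - Path does not exist" instead of a raw
# Java stack trace, so they can be caught like any other R condition.
msg <- tryCatch(
  read.json("/no/such/path"),          # hypothetical path
  error = function(e) conditionMessage(e)
)
print(msg)
```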
+ if (returnStatus < 0) { stop(readString(conn)) } - readObject(conn) } diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 2d341d836c13..9d82814211bc 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -19,7 +19,7 @@ # Creates a SparkR client connection object # if one doesn't already exist -connectBackend <- function(hostname, port, timeout = 6000) { +connectBackend <- function(hostname, port, timeout) { if (exists(".sparkRcon", envir = .sparkREnv)) { if (isOpen(.sparkREnv[[".sparkRCon"]])) { cat("SparkRBackend client connection already exists\n") diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 4d94b4cd05d4..f8a9d3ce5d91 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1485,7 +1485,7 @@ setMethod("soundex", #' Return the partition ID as a column #' -#' Return the partition ID of the Spark task as a SparkDataFrame column. +#' Return the partition ID as a SparkDataFrame column. #' Note that this is nondeterministic because it depends on data partitioning and #' task scheduling. #' @@ -2317,7 +2317,8 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' from_utc_timestamp #' -#' Assumes given timestamp is UTC and converts to given timezone. +#' Given a timestamp, which corresponds to a certain time of day in UTC, returns another timestamp +#' that corresponds to the same time of day in the given timezone. #' #' @param y Column to compute on. #' @param x time zone to use. @@ -2340,7 +2341,7 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), #' Locate the position of the first occurrence of substr column in the given string. #' Returns null if either of the arguments are null. #' -#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr +#' NOTE: The position is not zero based, but 1 based index. Returns 0 if substr #' could not be found in str. #' #' @param y column to check @@ -2391,7 +2392,8 @@ setMethod("next_day", signature(y = "Column", x = "character"), #' to_utc_timestamp #' -#' Assumes given timestamp is in given timezone and converts to UTC. +#' Given a timestamp, which corresponds to a certain time of day in the given timezone, returns +#' another timestamp that corresponds to the same time of day in UTC. #' #' @param y Column to compute on #' @param x timezone to use @@ -2539,7 +2541,7 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), #' shiftRight #' -#' Shift the given value numBits right. If the given value is a long value, it will return +#' (Signed) shift the given value numBits right. If the given value is a long value, it will return #' a long value else it will return an integer value. #' #' @param y column to compute on. @@ -2777,7 +2779,7 @@ setMethod("window", signature(x = "Column"), #' locate #' #' Locate the position of the first occurrence of substr. -#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr +#' NOTE: The position is not zero based, but 1 based index. Returns 0 if substr #' could not be found in str. #' #' @param substr a character string to be matched. @@ -2823,7 +2825,8 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' rand #' -#' Generate a random column with i.i.d. samples from U[0.0, 1.0]. +#' Generate a random column with independent and identically distributed (i.i.d.) samples +#' from U[0.0, 1.0]. #' #' @param seed a random seed. Can be missing. 
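The documentation fixes above clarify behavior that is easy to get wrong; a small hedged sketch of the clarified semantics follows (the data, time zone, and seed are made up for illustration).

```r
df <- createDataFrame(data.frame(t = as.POSIXct("2016-11-07 09:30:00", tz = "UTC"),
                                 s = "hello"))

# from_utc_timestamp() treats the value as a time of day in UTC and returns the
# same time of day in the target zone; to_utc_timestamp() goes the other way.
head(select(df, from_utc_timestamp(df$t, "America/Los_Angeles"),
                to_utc_timestamp(df$t, "America/Los_Angeles")))

# instr()/locate() positions are 1-based, and 0 means the substring was not found.
head(select(df, instr(df$s, "ll"), instr(df$s, "zz")))

# rand(seed)/randn(seed) append i.i.d. uniform / standard-normal columns.
head(select(df, rand(42), randn(42)))
```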
#' @family normal_funcs @@ -2852,7 +2855,8 @@ setMethod("rand", signature(seed = "numeric"), #' randn #' -#' Generate a column with i.i.d. samples from the standard normal distribution. +#' Generate a column with independent and identically distributed (i.i.d.) samples from +#' the standard normal distribution. #' #' @param seed a random seed. Can be missing. #' @family normal_funcs @@ -3442,8 +3446,8 @@ setMethod("size", #' sort_array #' -#' Sorts the input array for the given column in ascending order, -#' according to the natural ordering of the array elements. +#' Sorts the input array in ascending or descending order according +#' to the natural ordering of the array elements. #' #' @param x A Column to sort #' @param asc A logical flag indicating the sorting order. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 107e1c638be7..7653ca7bccec 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1310,9 +1310,11 @@ setGeneric("window", function(x, ...) { standardGeneric("window") }) #' @export setGeneric("year", function(x) { standardGeneric("year") }) -#' @rdname spark.glm +###################### Spark.ML Methods ########################## + +#' @rdname fitted #' @export -setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) +setGeneric("fitted") #' @param x,y For \code{glm}: logical values indicating whether the response vector #' and model matrix used in the fitting process should be returned as @@ -1332,13 +1334,42 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") }) #' @export setGeneric("rbind", signature = "...") +#' @rdname spark.als +#' @export +setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) + +#' @rdname spark.gaussianMixture +#' @export +setGeneric("spark.gaussianMixture", + function(data, formula, ...) { standardGeneric("spark.gaussianMixture") }) + +#' @rdname spark.gbt +#' @export +setGeneric("spark.gbt", function(data, formula, ...) { standardGeneric("spark.gbt") }) + +#' @rdname spark.glm +#' @export +setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) + +#' @rdname spark.isoreg +#' @export +setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") }) + #' @rdname spark.kmeans #' @export setGeneric("spark.kmeans", function(data, formula, ...) { standardGeneric("spark.kmeans") }) -#' @rdname fitted +#' @rdname spark.kstest #' @export -setGeneric("fitted") +setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") }) + +#' @rdname spark.lda +#' @export +setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") }) + +#' @rdname spark.logit +#' @export +setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.logit") }) #' @rdname spark.mlp #' @export @@ -1348,13 +1379,14 @@ setGeneric("spark.mlp", function(data, ...) { standardGeneric("spark.mlp") }) #' @export setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") }) -#' @rdname spark.survreg +#' @rdname spark.randomForest #' @export -setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") }) +setGeneric("spark.randomForest", + function(data, formula, ...) { standardGeneric("spark.randomForest") }) -#' @rdname spark.lda +#' @rdname spark.survreg #' @export -setGeneric("spark.lda", function(data, ...) 
{ standardGeneric("spark.lda") }) +setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") }) #' @rdname spark.lda #' @export @@ -1364,20 +1396,6 @@ setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark #' @export setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") }) -#' @rdname spark.isoreg -#' @export -setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") }) - -#' @rdname spark.gaussianMixture -#' @export -setGeneric("spark.gaussianMixture", - function(data, formula, ...) { - standardGeneric("spark.gaussianMixture") - }) - -#' @rdname spark.logit -#' @export -setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.logit") }) #' @param object a fitted ML model object. #' @param path the directory where the model is saved. @@ -1385,11 +1403,3 @@ setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark. #' @rdname write.ml #' @export setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") }) - -#' @rdname spark.als -#' @export -setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) - -#' @rdname spark.kstest -#' @export -setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") }) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 629f284b79f3..1065b4b37d7f 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -102,6 +102,34 @@ setClass("KSTest", representation(jobj = "jobj")) #' @note LogisticRegressionModel since 2.1.0 setClass("LogisticRegressionModel", representation(jobj = "jobj")) +#' S4 class that represents a RandomForestRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala RandomForestRegressionModel +#' @export +#' @note RandomForestRegressionModel since 2.1.0 +setClass("RandomForestRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a RandomForestClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala RandomForestClassificationModel +#' @export +#' @note RandomForestClassificationModel since 2.1.0 +setClass("RandomForestClassificationModel", representation(jobj = "jobj")) + +#' S4 class that represents a GBTRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala GBTRegressionModel +#' @export +#' @note GBTRegressionModel since 2.1.0 +setClass("GBTRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a GBTClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala GBTClassificationModel +#' @export +#' @note GBTClassificationModel since 2.1.0 +setClass("GBTClassificationModel", representation(jobj = "jobj")) + #' Saves the MLlib model to the input path #' #' Saves the MLlib model to the input path. 
For more information, see the specific @@ -110,9 +138,10 @@ setClass("LogisticRegressionModel", representation(jobj = "jobj")) #' @name write.ml #' @export #' @seealso \link{spark.glm}, \link{glm}, -#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, +#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.gbt}, \link{spark.isoreg}, +#' @seealso \link{spark.kmeans}, #' @seealso \link{spark.lda}, \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, -#' @seealso \link{spark.survreg} +#' @seealso \link{spark.randomForest}, \link{spark.survreg}, #' @seealso \link{read.ml} NULL @@ -124,8 +153,10 @@ NULL #' @name predict #' @export #' @seealso \link{spark.glm}, \link{glm}, -#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, -#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg} +#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.gbt}, \link{spark.isoreg}, +#' @seealso \link{spark.kmeans}, +#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, +#' @seealso \link{spark.randomForest}, \link{spark.survreg} NULL write_internal <- function(object, path, overwrite = FALSE) { @@ -619,7 +650,7 @@ setMethod("fitted", signature(object = "KMeansModel"), # Get the summary of a k-means model #' @param object a fitted k-means model. -#' @return \code{summary} returns the model's coefficients, size and cluster. +#' @return \code{summary} returns the model's features, coefficients, k, size and cluster. #' @rdname spark.kmeans #' @export #' @note summary(KMeansModel) since 2.0.0 @@ -664,15 +695,15 @@ setMethod("predict", signature(object = "KMeansModel"), #' @param data SparkDataFrame for training #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param regParam the regularization parameter. Default is 0.0. +#' @param regParam the regularization parameter. #' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 penalty. #' For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, the penalty is a combination #' of L1 and L2. Default is 0.0 which is an L2 penalty. #' @param maxIter maximum iteration number. #' @param tol convergence tolerance of iterations. -#' @param fitIntercept whether to fit an intercept term. Default is TRUE. +#' @param fitIntercept whether to fit an intercept term. #' @param family the name of family which is a description of the label distribution to be used in the model. -#' Supported options: Default is "auto". +#' Supported options: #' \itemize{ #' \item{"auto": Automatically select the family based on the number of classes: #' If number of classes == 1 || number of classes == 2, set to "binomial". @@ -690,11 +721,11 @@ setMethod("predict", signature(object = "KMeansModel"), #' threshold p is equivalent to setting thresholds c(1-p, p). In multiclass (or binary) classification to adjust the probability of #' predicting each class. Array must have length equal to the number of classes, with values > 0, #' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p -#' is the original probability of that class and t is the class's threshold. Default is 0.5. +#' is the original probability of that class and t is the class's threshold. #' @param weightCol The weight column name. 
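Since the spark.logit parameter documentation is being tightened here, a short hedged usage sketch may help; the two-class subset, formula, and parameter values are illustrative, and summary() is documented above as exposing binary-classification metrics only.

```r
training <- createDataFrame(iris[iris$Species != "virginica", ])
model <- spark.logit(training, Species ~ Sepal_Length + Sepal_Width,
                     regParam = 0.01, elasticNetParam = 0.0, maxIter = 100)

# Binary summary fields per the @return note above (roc, areaUnderROC, ...).
summ <- summary(model)
summ$areaUnderROC

head(select(predict(model, training), "Species", "prediction"))
```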
#' @param aggregationDepth depth for treeAggregate (>= 2). If the dimensions of features or the number of partitions -#' are large, this param could be adjusted to a larger size. Default is 2. -#' @param probabilityCol column name for predicted class conditional probabilities. Default is "probability". +#' are large, this param could be adjusted to a larger size. +#' @param probabilityCol column name for predicted class conditional probabilities. #' @param ... additional arguments passed to the method. #' @return \code{spark.logit} returns a fitted logistic regression model #' @rdname spark.logit @@ -776,8 +807,10 @@ setMethod("predict", signature(object = "LogisticRegressionModel"), # Get the summary of an LogisticRegressionModel #' @param object an LogisticRegressionModel fitted by \code{spark.logit} -#' @return \code{summary} returns the Binary Logistic regression results of a given model as lists. Note that -#' Multinomial logistic regression summary is not available now. +#' @return \code{summary} returns the Binary Logistic regression results of a given model as list, +#' including roc, areaUnderROC, pr, fMeasureByThreshold, precisionByThreshold, +#' recallByThreshold, totalIterations, objectiveHistory. Note that Multinomial logistic +#' regression summary is not available now. #' @rdname spark.logit #' @aliases summary,LogisticRegressionModel-method #' @export @@ -1122,6 +1155,14 @@ read.ml <- function(path) { new("ALSModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LogisticRegressionWrapper")) { new("LogisticRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestRegressorWrapper")) { + new("RandomForestRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) { + new("RandomForestClassificationModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTRegressorWrapper")) { + new("GBTRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTClassifierWrapper")) { + new("GBTClassificationModel", jobj = jobj) } else { stop("Unsupported model: ", jobj) } @@ -1177,13 +1218,13 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula #' data and \code{write.ml}/\code{read.ml} to save/load fitted models. #' #' @param data A SparkDataFrame for training -#' @param features Features column name, default "features". Either libSVM-format column or -#' character-format column is valid. -#' @param k Number of topics, default 10 -#' @param maxIter Maximum iterations, default 20 -#' @param optimizer Optimizer to train an LDA model, "online" or "em", default "online" +#' @param features Features column name. Either libSVM-format column or character-format column is +#' valid. +#' @param k Number of topics. +#' @param maxIter Maximum iterations. +#' @param optimizer Optimizer to train an LDA model, "online" or "em", default is "online". #' @param subsamplingRate (For online optimizer) Fraction of the corpus to be sampled and used in -#' each iteration of mini-batch gradient descent, in range (0, 1], default 0.05 +#' each iteration of mini-batch gradient descent, in range (0, 1]. #' @param topicConcentration concentration parameter (commonly named \code{beta} or \code{eta}) for #' the prior placed on topic distributions over terms, default -1 to set automatically on the #' Spark side. Use \code{summary} to retrieve the effective topicConcentration. 
Only 1-size @@ -1244,7 +1285,7 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"), # similarly to R's summary(). #' @param object a fitted AFT survival regression model. -#' @return \code{summary} returns a list containing the model's coefficients, +#' @return \code{summary} returns a list containing the model's features, coefficients, #' intercept and log(scale) #' @rdname spark.survreg #' @export @@ -1332,7 +1373,7 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = # Get the summary of a multivariate gaussian mixture model #' @param object a fitted gaussian mixture model. -#' @return \code{summary} returns the model's lambda, mu, sigma and posterior. +#' @return \code{summary} returns the model's lambda, mu, sigma, k, dim and posterior. #' @aliases spark.gaussianMixture,SparkDataFrame,formula-method #' @rdname spark.gaussianMixture #' @export @@ -1617,3 +1658,451 @@ print.summary.KSTest <- function(x, ...) { cat(summaryStr, "\n") invisible(x) } + +#' Random Forest Model for Regression and Classification +#' +#' \code{spark.randomForest} fits a Random Forest Regression model or Classification model on +#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Random Forest +#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to +#' save/load fitted models. +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-regression}{ +#' Random Forest Regression} and +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-classifier}{ +#' Random Forest Classification} +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' @param type type of model, one of "regression" or "classification", to fit +#' @param maxDepth Maximum depth of the tree (>= 0). +#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing +#' how to split on features at each node. More bins give higher granularity. Must be +#' >= 2 and >= number of categories in any categorical feature. +#' @param numTrees Number of trees to train (>= 1). +#' @param impurity Criterion used for information gain calculation. +#' For regression, must be "variance". For classification, must be one of +#' "entropy" and "gini", default is "gini". +#' @param featureSubsetStrategy The number of features to consider for splits at each tree node. +#' Supported options: "auto", "all", "onethird", "sqrt", "log2", (0.0-1.0], [1-n]. +#' @param seed integer seed for random number generation. +#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in +#' range (0, 1]. +#' @param minInstancesPerNode Minimum number of instances each child must have after split. +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. +#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. +#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with +#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching +#' can speed up training of deeper trees. 
Users can set how often should the +#' cache be checkpointed or disable it by setting checkpointInterval. +#' @param probabilityCol column name for predicted class conditional probabilities, only for +#' classification. +#' @param ... additional arguments passed to the method. +#' @aliases spark.randomForest,SparkDataFrame,formula-method +#' @return \code{spark.randomForest} returns a fitted Random Forest model. +#' @rdname spark.randomForest +#' @name spark.randomForest +#' @export +#' @examples +#' \dontrun{ +#' # fit a Random Forest Regression Model +#' df <- createDataFrame(longley) +#' model <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # fit a Random Forest Classification Model +#' df <- createDataFrame(iris) +#' model <- spark.randomForest(df, Species ~ Petal_Length + Petal_Width, "classification") +#' } +#' @note spark.randomForest since 2.1.0 +setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, type = c("regression", "classification"), + maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL, + featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0, + minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, + maxMemoryInMB = 256, cacheNodeIds = FALSE, probabilityCol = "probability") { + type <- match.arg(type) + formula <- paste(deparse(formula), collapse = "") + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } + switch(type, + regression = { + if (is.null(impurity)) impurity <- "variance" + impurity <- match.arg(impurity, "variance") + jobj <- callJStatic("org.apache.spark.ml.r.RandomForestRegressorWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(numTrees), + impurity, as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + as.character(featureSubsetStrategy), seed, + as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("RandomForestRegressionModel", jobj = jobj) + }, + classification = { + if (is.null(impurity)) impurity <- "gini" + impurity <- match.arg(impurity, c("gini", "entropy")) + jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(numTrees), + impurity, as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + as.character(featureSubsetStrategy), seed, + as.numeric(subsamplingRate), as.character(probabilityCol), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("RandomForestClassificationModel", jobj = jobj) + } + ) + }) + +# Makes predictions from a Random Forest Regression model or Classification model + +#' @param newData a SparkDataFrame for testing. 
+#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
+#'         "prediction"
+#' @rdname spark.randomForest
+#' @aliases predict,RandomForestRegressionModel-method
+#' @export
+#' @note predict(RandomForestRegressionModel) since 2.1.0
+setMethod("predict", signature(object = "RandomForestRegressionModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
+#' @rdname spark.randomForest
+#' @aliases predict,RandomForestClassificationModel-method
+#' @export
+#' @note predict(RandomForestClassificationModel) since 2.1.0
+setMethod("predict", signature(object = "RandomForestClassificationModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
+# Save the Random Forest Regression or Classification model to the input path.
+
+#' @param object A fitted Random Forest regression model or classification model
+#' @param path The directory where the model is saved
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#'                  which means throw exception if the output path exists.
+#'
+#' @aliases write.ml,RandomForestRegressionModel,character-method
+#' @rdname spark.randomForest
+#' @export
+#' @note write.ml(RandomForestRegressionModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })
+
+#' @aliases write.ml,RandomForestClassificationModel,character-method
+#' @rdname spark.randomForest
+#' @export
+#' @note write.ml(RandomForestClassificationModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "RandomForestClassificationModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })
+
+# Create the summary of a tree ensemble model (e.g.
Random Forest, GBT) +summary.treeEnsemble <- function(model) { + jobj <- model@jobj + formula <- callJMethod(jobj, "formula") + numFeatures <- callJMethod(jobj, "numFeatures") + features <- callJMethod(jobj, "features") + featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString") + numTrees <- callJMethod(jobj, "numTrees") + treeWeights <- callJMethod(jobj, "treeWeights") + list(formula = formula, + numFeatures = numFeatures, + features = features, + featureImportances = featureImportances, + numTrees = numTrees, + treeWeights = treeWeights, + jobj = jobj) +} + +# Get the summary of a Random Forest Regression Model + +#' @return \code{summary} returns a summary object of the fitted model, a list of components +#' including formula, number of features, list of features, feature importances, number of +#' trees, and tree weights +#' @rdname spark.randomForest +#' @aliases summary,RandomForestRegressionModel-method +#' @export +#' @note summary(RandomForestRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "RandomForestRegressionModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.RandomForestRegressionModel" + ans + }) + +# Get the summary of a Random Forest Classification Model + +#' @rdname spark.randomForest +#' @aliases summary,RandomForestClassificationModel-method +#' @export +#' @note summary(RandomForestClassificationModel) since 2.1.0 +setMethod("summary", signature(object = "RandomForestClassificationModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.RandomForestClassificationModel" + ans + }) + +# Prints the summary of tree ensemble models (eg. Random Forest, GBT) +print.summary.treeEnsemble <- function(x) { + jobj <- x$jobj + cat("Formula: ", x$formula) + cat("\nNumber of features: ", x$numFeatures) + cat("\nFeatures: ", unlist(x$features)) + cat("\nFeature importances: ", x$featureImportances) + cat("\nNumber of trees: ", x$numTrees) + cat("\nTree weights: ", unlist(x$treeWeights)) + + summaryStr <- callJMethod(jobj, "summary") + cat("\n", summaryStr, "\n") + invisible(x) +} + +# Prints the summary of Random Forest Regression Model + +#' @param x summary object of Random Forest regression model or classification model +#' returned by \code{summary}. +#' @rdname spark.randomForest +#' @export +#' @note print.summary.RandomForestRegressionModel since 2.1.0 +print.summary.RandomForestRegressionModel <- function(x, ...) { + print.summary.treeEnsemble(x) +} + +# Prints the summary of Random Forest Classification Model + +#' @rdname spark.randomForest +#' @export +#' @note print.summary.RandomForestClassificationModel since 2.1.0 +print.summary.RandomForestClassificationModel <- function(x, ...) { + print.summary.treeEnsemble(x) +} + +#' Gradient Boosted Tree Model for Regression and Classification +#' +#' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a +#' SparkDataFrame. Users can call \code{summary} to get a summary of the fitted +#' Gradient Boosted Tree model, \code{predict} to make predictions on new data, and +#' \code{write.ml}/\code{read.ml} to save/load fitted models. 
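The summary.treeEnsemble() and print.summary.treeEnsemble() helpers above back both the random forest and the GBT wrappers. A hedged sketch of the resulting workflow using spark.randomForest (the path and parameter values are illustrative; spark.gbt behaves analogously):

```r
df <- createDataFrame(longley)
model <- spark.randomForest(df, Employed ~ ., type = "regression",
                            numTrees = 20, maxDepth = 5, seed = 123)

# summary() returns the list assembled by summary.treeEnsemble():
# formula, numFeatures, features, featureImportances, numTrees, treeWeights.
stats <- summary(model)
stats$numTrees
stats$featureImportances

# print() dispatches to the print.summary.* S3 methods registered in the NAMESPACE.
print(stats)

head(select(predict(model, df), "prediction"))

# Models round-trip through write.ml()/read.ml(); pass overwrite = TRUE
# if the target directory already exists.
modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp")
write.ml(model, modelPath)
summary(read.ml(modelPath))
```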
+#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-regression}{ +#' GBT Regression} and +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-classifier}{ +#' GBT Classification} +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' @param type type of model, one of "regression" or "classification", to fit +#' @param maxDepth Maximum depth of the tree (>= 0). +#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing +#' how to split on features at each node. More bins give higher granularity. Must be +#' >= 2 and >= number of categories in any categorical feature. +#' @param maxIter Param for maximum number of iterations (>= 0). +#' @param stepSize Param for Step size to be used for each iteration of optimization. +#' @param lossType Loss function which GBT tries to minimize. +#' For classification, must be "logistic". For regression, must be one of +#' "squared" (L2) and "absolute" (L1), default is "squared". +#' @param seed integer seed for random number generation. +#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in +#' range (0, 1]. +#' @param minInstancesPerNode Minimum number of instances each child must have after split. If a +#' split causes the left or right child to have fewer than +#' minInstancesPerNode, the split will be discarded as invalid. Should be +#' >= 1. +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. +#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. +#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with +#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching +#' can speed up training of deeper trees. Users can set how often should the +#' cache be checkpointed or disable it by setting checkpointInterval. +#' @param ... additional arguments passed to the method. +#' @aliases spark.gbt,SparkDataFrame,formula-method +#' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model. +#' @rdname spark.gbt +#' @name spark.gbt +#' @export +#' @examples +#' \dontrun{ +#' # fit a Gradient Boosted Tree Regression Model +#' df <- createDataFrame(longley) +#' model <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # fit a Gradient Boosted Tree Classification Model +#' # label must be binary - Only binary classification is supported for GBT. 
+#' df <- createDataFrame(iris[iris$Species != "virginica", ]) +#' model <- spark.gbt(df, Species ~ Petal_Length + Petal_Width, "classification") +#' +#' # numeric label is also supported +#' iris2 <- iris[iris$Species != "virginica", ] +#' iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) +#' df <- createDataFrame(iris2) +#' model <- spark.gbt(df, NumericSpecies ~ ., type = "classification") +#' } +#' @note spark.gbt since 2.1.0 +setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, type = c("regression", "classification"), + maxDepth = 5, maxBins = 32, maxIter = 20, stepSize = 0.1, lossType = NULL, + seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0, + checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE) { + type <- match.arg(type) + formula <- paste(deparse(formula), collapse = "") + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } + switch(type, + regression = { + if (is.null(lossType)) lossType <- "squared" + lossType <- match.arg(lossType, c("squared", "absolute")) + jobj <- callJStatic("org.apache.spark.ml.r.GBTRegressorWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(maxIter), + as.numeric(stepSize), as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + lossType, seed, as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("GBTRegressionModel", jobj = jobj) + }, + classification = { + if (is.null(lossType)) lossType <- "logistic" + lossType <- match.arg(lossType, "logistic") + jobj <- callJStatic("org.apache.spark.ml.r.GBTClassifierWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(maxIter), + as.numeric(stepSize), as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + lossType, seed, as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("GBTClassificationModel", jobj = jobj) + } + ) + }) + +# Makes predictions from a Gradient Boosted Tree Regression model or Classification model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named +#' "prediction" +#' @rdname spark.gbt +#' @aliases predict,GBTRegressionModel-method +#' @export +#' @note predict(GBTRegressionModel) since 2.1.0 +setMethod("predict", signature(object = "GBTRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' @rdname spark.gbt +#' @aliases predict,GBTClassificationModel-method +#' @export +#' @note predict(GBTClassificationModel) since 2.1.0 +setMethod("predict", signature(object = "GBTClassificationModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Save the Gradient Boosted Tree Regression or Classification model to the input path. + +#' @param object A fitted Gradient Boosted Tree regression model or classification model +#' @param path The directory where the model is saved +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. 
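To round this out, a hedged sketch of spark.gbt for classification, mirroring the tests further below: GBT classification is binary-only, and a 0/1 numeric label works as well as a two-level string label (the data and parameter values are illustrative).

```r
iris2 <- iris[iris$Species != "virginica", ]
iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1)
df <- createDataFrame(iris2)

# lossType defaults to "logistic" for classification and "squared" for regression.
model <- spark.gbt(df, NumericSpecies ~ Petal_Length + Petal_Width,
                   type = "classification", maxIter = 20, stepSize = 0.1)

stats <- summary(model)
stats$numTrees                         # one tree per boosting iteration
head(select(predict(model, df), "prediction"))
```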
+#' @aliases write.ml,GBTRegressionModel,character-method +#' @rdname spark.gbt +#' @export +#' @note write.ml(GBTRegressionModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' @aliases write.ml,GBTClassificationModel,character-method +#' @rdname spark.gbt +#' @export +#' @note write.ml(GBTClassificationModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "GBTClassificationModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +# Get the summary of a Gradient Boosted Tree Regression Model + +#' @return \code{summary} returns a summary object of the fitted model, a list of components +#' including formula, number of features, list of features, feature importances, number of +#' trees, and tree weights +#' @rdname spark.gbt +#' @aliases summary,GBTRegressionModel-method +#' @export +#' @note summary(GBTRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "GBTRegressionModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.GBTRegressionModel" + ans + }) + +# Get the summary of a Gradient Boosted Tree Classification Model + +#' @rdname spark.gbt +#' @aliases summary,GBTClassificationModel-method +#' @export +#' @note summary(GBTClassificationModel) since 2.1.0 +setMethod("summary", signature(object = "GBTClassificationModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.GBTClassificationModel" + ans + }) + +# Prints the summary of Gradient Boosted Tree Regression Model + +#' @param x summary object of Gradient Boosted Tree regression model or classification model +#' returned by \code{summary}. +#' @rdname spark.gbt +#' @export +#' @note print.summary.GBTRegressionModel since 2.1.0 +print.summary.GBTRegressionModel <- function(x, ...) { + print.summary.treeEnsemble(x) +} + +# Prints the summary of Gradient Boosted Tree Classification Model + +#' @rdname spark.gbt +#' @export +#' @note print.summary.GBTClassificationModel since 2.1.0 +print.summary.GBTClassificationModel <- function(x, ...) 
{ + print.summary.treeEnsemble(x) +} diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index cc6d591bb2f4..6b4a2f2fdc85 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -154,6 +154,7 @@ sparkR.sparkContext <- function( packages <- processSparkPackages(sparkPackages) existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "") + connectionTimeout <- as.numeric(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000")) if (existingPort != "") { if (length(packages) != 0) { warning(paste("sparkPackages has no effect when using spark-submit or sparkR shell", @@ -187,6 +188,7 @@ sparkR.sparkContext <- function( backendPort <- readInt(f) monitorPort <- readInt(f) rLibPath <- readString(f) + connectionTimeout <- readInt(f) close(f) file.remove(path) if (length(backendPort) == 0 || backendPort == 0 || @@ -194,7 +196,9 @@ sparkR.sparkContext <- function( length(rLibPath) != 1) { stop("JVM failed to launch") } - assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv) + assign(".monitorConn", + socketConnection(port = monitorPort, timeout = connectionTimeout), + envir = .sparkREnv) assign(".backendLaunched", 1, envir = .sparkREnv) if (rLibPath != "") { assign(".libPath", rLibPath, envir = .sparkREnv) @@ -204,7 +208,7 @@ sparkR.sparkContext <- function( .sparkREnv$backendPort <- backendPort tryCatch({ - connectBackend("localhost", backendPort) + connectBackend("localhost", backendPort, timeout = connectionTimeout) }, error = function(err) { stop("Failed to connect JVM\n") diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index c4e78cbb804d..20004549cc03 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -338,21 +338,41 @@ varargsToEnv <- function(...) { # into string. varargsToStrEnv <- function(...) { pairs <- list(...) + nameList <- names(pairs) env <- new.env() - for (name in names(pairs)) { - value <- pairs[[name]] - if (!(is.logical(value) || is.numeric(value) || is.character(value) || is.null(value))) { - stop(paste0("Unsupported type for ", name, " : ", class(value), - ". Supported types are logical, numeric, character and NULL.")) - } - if (is.logical(value)) { - env[[name]] <- tolower(as.character(value)) - } else if (is.null(value)) { - env[[name]] <- value - } else { - env[[name]] <- as.character(value) + ignoredNames <- list() + + if (is.null(nameList)) { + # When all arguments are not named, names(..) returns NULL. + ignoredNames <- pairs + } else { + for (i in seq_along(pairs)) { + name <- nameList[i] + value <- pairs[i] + if (identical(name, "")) { + # When some of arguments are not named, name is "". + ignoredNames <- append(ignoredNames, value) + } else { + value <- pairs[[name]] + if (!(is.logical(value) || is.numeric(value) || is.character(value) || is.null(value))) { + stop(paste0("Unsupported type for ", name, " : ", class(value), + ". Supported types are logical, numeric, character and NULL."), call. = FALSE) + } + if (is.logical(value)) { + env[[name]] <- tolower(as.character(value)) + } else if (is.null(value)) { + env[[name]] <- value + } else { + env[[name]] <- as.character(value) + } + } } } + + if (length(ignoredNames) != 0) { + warning(paste0("Unnamed arguments ignored: ", paste(ignoredNames, collapse = ", "), "."), + call. 
= FALSE) + } env } diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 6d1fccc7c058..33e9d0d267ac 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -64,6 +64,16 @@ test_that("spark.glm and predict", { rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + # binomial family + binomialTraining <- training[training$Species %in% c("versicolor", "virginica"), ] + model <- spark.glm(binomialTraining, Species ~ Sepal_Length + Sepal_Width, + family = binomial(link = "logit")) + prediction <- predict(model, binomialTraining) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "character") + expected <- c("virginica", "virginica", "virginica", "versicolor", "virginica", + "versicolor", "virginica", "versicolor", "virginica", "versicolor") + expect_equal(as.list(take(select(prediction, "prediction"), 10))[[1]], expected) + # poisson family model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species, family = poisson(link = identity)) @@ -128,12 +138,12 @@ test_that("spark.glm summary", { expect_equal(stats$aic, rStats$aic) # Test spark.glm works with weighted dataset - a1 <- c(0, 1, 2, 3) - a2 <- c(5, 2, 1, 3) - w <- c(1, 2, 3, 4) - b <- c(1, 0, 1, 0) + a1 <- c(0, 1, 2, 3, 4) + a2 <- c(5, 2, 1, 3, 2) + w <- c(1, 2, 3, 4, 5) + b <- c(1, 0, 1, 0, 0) data <- as.data.frame(cbind(a1, a2, w, b)) - df <- suppressWarnings(createDataFrame(data)) + df <- createDataFrame(data) stats <- summary(spark.glm(df, b ~ a1 + a2, family = "binomial", weightCol = "w")) rStats <- summary(glm(b ~ a1 + a2, family = "binomial", data = data, weights = w)) @@ -158,7 +168,7 @@ test_that("spark.glm summary", { data <- as.data.frame(cbind(a1, a2, b)) df <- suppressWarnings(createDataFrame(data)) regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0)) - expect_equal(regStats$aic, 13.32836, tolerance = 1e-4) # 13.32836 is from summary() result + expect_equal(regStats$aic, 14.00976, tolerance = 1e-4) # 14.00976 is from summary() result }) test_that("spark.glm save/load", { @@ -575,7 +585,7 @@ test_that("spark.isotonicRegression", { feature <- c(0.0, 1.0, 2.0, 3.0, 4.0) weight <- c(1.0, 1.0, 1.0, 1.0, 1.0) data <- as.data.frame(cbind(label, feature, weight)) - df <- suppressWarnings(createDataFrame(data)) + df <- createDataFrame(data) model <- spark.isoreg(df, label ~ feature, isotonic = FALSE, weightCol = "weight") @@ -871,4 +881,140 @@ test_that("spark.kstest", { expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") }) +test_that("spark.randomForest Regression", { + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, + numTrees = 1) + + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + + stats <- summary(model) + expect_equal(stats$numTrees, 1) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + + model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, + numTrees = 20, seed = 123) + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.379, 61.096, 60.636, 62.258, + 63.736, 64.296, 64.868, 64.300, + 66.709, 
67.697, 67.966, 67.252, + 68.866, 69.593, 69.195, 69.658), + tolerance = 1e-4) + stats <- summary(model) + expect_equal(stats$numTrees, 20) + + modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) +}) + +test_that("spark.randomForest Classification", { + data <- suppressWarnings(createDataFrame(iris)) + model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + + modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) +}) + +test_that("spark.gbt", { + # regression + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123) + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + stats <- summary(model) + expect_equal(stats$numTrees, 20) + expect_equal(stats$formula, "Employed ~ .") + expect_equal(stats$numFeatures, 6) + expect_equal(length(stats$treeWeights), 20) + + modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) + + # classification + # label must be binary - GBTClassifier currently only supports binary classification. 
+  iris2 <- iris[iris$Species != "virginica", ]
+  data <- suppressWarnings(createDataFrame(iris2))
+  model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification")
+  stats <- summary(model)
+  expect_equal(stats$numFeatures, 2)
+  expect_equal(stats$numTrees, 20)
+  expect_error(capture.output(stats), NA)
+  expect_true(length(capture.output(stats)) > 6)
+  predictions <- collect(predict(model, data))$prediction
+  # test string prediction values
+  expect_equal(length(grep("setosa", predictions)), 50)
+  expect_equal(length(grep("versicolor", predictions)), 50)
+
+  modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+  expect_equal(stats$depth, stats2$depth)
+  expect_equal(stats$numNodes, stats2$numNodes)
+  expect_equal(stats$numClasses, stats2$numClasses)
+
+  unlink(modelPath)
+
+  iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1)
+  df <- suppressWarnings(createDataFrame(iris2))
+  m <- spark.gbt(df, NumericSpecies ~ ., type = "classification")
+  s <- summary(m)
+  # test numeric prediction values
+  expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction))
+  expect_equal(s$numFeatures, 5)
+  expect_equal(s$numTrees, 20)
+})
+
 sparkR.session.stop()
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 9289db57b6d6..ee48baa59c7a 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1222,16 +1222,16 @@ test_that("column functions", {
   # Test struct()
   df <- createDataFrame(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), schema = c("a", "b", "c"))
-  result <- collect(select(df, struct("a", "c")))
+  result <- collect(select(df, alias(struct("a", "c"), "d")))
   expected <- data.frame(row.names = 1:2)
-  expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)),
-                                  listToStruct(list(a = 4L, c = 6L)))
+  expected$"d" <- list(listToStruct(list(a = 1L, c = 3L)),
+                       listToStruct(list(a = 4L, c = 6L)))
   expect_equal(result, expected)

-  result <- collect(select(df, struct(df$a, df$b)))
+  result <- collect(select(df, alias(struct(df$a, df$b), "d")))
   expected <- data.frame(row.names = 1:2)
-  expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)),
-                                  listToStruct(list(a = 4L, b = 5L)))
+  expected$"d" <- list(listToStruct(list(a = 1L, b = 2L)),
+                       listToStruct(list(a = 4L, b = 5L)))
   expect_equal(result, expected)

   # Test encode(), decode()
@@ -2659,7 +2659,15 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume
   # It makes sure that we can omit path argument in write.df API and then it calls
   # DataFrameWriter.save() without path.
   expect_error(write.df(df, source = "csv"),
-               "Error in save : illegal argument - 'path' is not specified")
+               "Error in save : illegal argument - Expected exactly one path to be specified")
+  expect_error(write.json(df, jsonPath),
+               "Error in json : analysis error - path file:.*already exists")
+  expect_error(write.text(df, jsonPath),
+               "Error in text : analysis error - path file:.*already exists")
+  expect_error(write.orc(df, jsonPath),
+               "Error in orc : analysis error - path file:.*already exists")
+  expect_error(write.parquet(df, jsonPath),
+               "Error in parquet : analysis error - path file:.*already exists")

   # Arguments checking in R side.
   expect_error(write.df(df, "data.tmp", source = c(1, 2)),
@@ -2679,6 +2687,11 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume
                  paste("Error in loadDF : analysis error - Unable to infer schema for JSON at .",
                        "It must be specified manually"))
   expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist")
+  expect_error(read.json("arbitrary_path"), "Error in json : analysis error - Path does not exist")
+  expect_error(read.text("arbitrary_path"), "Error in text : analysis error - Path does not exist")
+  expect_error(read.orc("arbitrary_path"), "Error in orc : analysis error - Path does not exist")
+  expect_error(read.parquet("arbitrary_path"),
+               "Error in parquet : analysis error - Path does not exist")

   # Arguments checking in R side.
   expect_error(read.df(path = c(3)),
@@ -2686,6 +2699,9 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume
   expect_error(read.df(jsonPath, source = c(1, 2)),
                paste("source should be character, NULL or omitted. It is the datasource specified",
                      "in 'spark.sql.sources.default' configuration by default."))
+
+  expect_warning(read.json(jsonPath, a = 1, 2, 3, "a"),
+                 "Unnamed arguments ignored: 2, 3, a.")
 })

 unlink(parquetPath)
diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R
index a20254e9b3fa..607c407f04f9 100644
--- a/R/pkg/inst/tests/testthat/test_utils.R
+++ b/R/pkg/inst/tests/testthat/test_utils.R
@@ -224,6 +224,8 @@ test_that("varargsToStrEnv", {
   expect_error(varargsToStrEnv(a = list(1, "a")),
                paste0("Unsupported type for a : list. Supported types are logical, ",
                       "numeric, character and NULL."))
+  expect_warning(varargsToStrEnv(a = 1, 2, 3, 4), "Unnamed arguments ignored: 2, 3, 4.")
+  expect_warning(varargsToStrEnv(1, 2, 3, 4), "Unnamed arguments ignored: 1, 2, 3, 4.")
 })

 sparkR.session.stop()
diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R
index b92e6be995ca..3a318b71ea06 100644
--- a/R/pkg/inst/worker/daemon.R
+++ b/R/pkg/inst/worker/daemon.R
@@ -18,6 +18,7 @@
 # Worker daemon

 rLibDir <- Sys.getenv("SPARKR_RLIBDIR")
+connectionTimeout <- as.integer(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
 dirs <- strsplit(rLibDir, ",")[[1]]
 script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R")
@@ -26,7 +27,8 @@ script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R")
 suppressPackageStartupMessages(library(SparkR))

 port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
-inputCon <- socketConnection(port = port, open = "rb", blocking = TRUE, timeout = 3600)
+inputCon <- socketConnection(
+    port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout)

 while (TRUE) {
   ready <- socketSelect(list(inputCon))
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index cfe41ded200c..03e745014786 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -90,6 +90,7 @@ bootTime <- currentTimeSecs()
 bootElap <- elapsedSecs()

 rLibDir <- Sys.getenv("SPARKR_RLIBDIR")
+connectionTimeout <- as.integer(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
 dirs <- strsplit(rLibDir, ",")[[1]]
 # Set libPaths to include SparkR package as loadNamespace needs this
 # TODO: Figure out if we can avoid this by not loading any objects that require
@@ -98,8 +99,10 @@ dirs <- strsplit(rLibDir, ",")[[1]]
 suppressPackageStartupMessages(library(SparkR))

 port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
-inputCon <- socketConnection(port = port, blocking = TRUE, open = "rb")
-outputCon <- socketConnection(port = port, blocking = TRUE, open = "wb")
+inputCon <- socketConnection(
+    port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout)
+outputCon <- socketConnection(
+    port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)

 # read the index of the current partition inside the RDD
 partition <- SparkR:::readInt(inputCon)
diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html
index 1fd6ef4a7125..42e2d9abdeb5 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html
+++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html
@@ -68,16 +68,16 @@
{{#applications}}" +
+ threadCell.after("" +
stackTraceText + " ")
} else {
if (!forceAdd) {
@@ -73,6 +73,7 @@ function onMouseOverAndOut(threadId) {
$("#" + threadId + "_td_id").toggleClass("threaddump-td-mouseover");
$("#" + threadId + "_td_name").toggleClass("threaddump-td-mouseover");
$("#" + threadId + "_td_state").toggleClass("threaddump-td-mouseover");
+ $("#" + threadId + "_td_locking").toggleClass("threaddump-td-mouseover");
}
function onSearchStringChange() {
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js
index e37307aa1f70..0fa1fcf25f8b 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js
@@ -15,6 +15,12 @@
* limitations under the License.
*/
+var uiRoot = "";
+
+function setUIRoot(val) {
+ uiRoot = val;
+}
+
function collapseTablePageLoad(name, table){
if (window.localStorage.getItem(name) == "true") {
// Set it to false so that the click function can revert it
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 4694790c72cd..25a3d609a6b0 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -183,6 +183,8 @@ class SparkContext(config: SparkConf) extends Logging {
// log out Spark Version in Spark driver log
logInfo(s"Running Spark version $SPARK_VERSION")
+ warnDeprecatedVersions()
+
/* ------------------------------------------------------------------------------------- *
| Private variables. These variables keep the internal state of the context, and are |
| not accessible by the outside world. They're mutable since we want to initialize all |
@@ -346,6 +348,16 @@ class SparkContext(config: SparkConf) extends Logging {
value
}
+ private def warnDeprecatedVersions(): Unit = {
+ val javaVersion = System.getProperty("java.version").split("[+.\\-]+", 3)
+ if (javaVersion.length >= 2 && javaVersion(1).toInt == 7) {
+ logWarning("Support for Java 7 is deprecated as of Spark 2.0.0")
+ }
+ if (scala.util.Properties.releaseVersion.exists(_.startsWith("2.10"))) {
+ logWarning("Support for Scala 2.10 is deprecated as of Spark 2.1.0")
+ }
+ }
+
/** Control our logLevel. This overrides any user-defined log settings.
* @param logLevel The desired log level as a string.
* Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN
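As a side note on the deprecation check added above: for a Java 7 runtime the `java.version` property is a string such as "1.7.0_80", which the split pattern `[+.\-]+` (limited to 3 fields) turns into ("1", "7", "0_80"), so the second field identifies the major version. A standalone sketch of the same check, with println standing in for logWarning (illustrative only):

    object DeprecatedVersionCheckSketch {
      def main(args: Array[String]): Unit = {
        // Same parsing as warnDeprecatedVersions: "1.7.0_80" -> Array("1", "7", "0_80")
        val javaVersion = System.getProperty("java.version").split("[+.\\-]+", 3)
        if (javaVersion.length >= 2 && javaVersion(1).toInt == 7) {
          println("Support for Java 7 is deprecated as of Spark 2.0.0")
        }
        // releaseVersion is e.g. Some("2.10.6") when running on Scala 2.10
        if (scala.util.Properties.releaseVersion.exists(_.startsWith("2.10"))) {
          println("Support for Scala 2.10 is deprecated as of Spark 2.1.0")
        }
      }
    }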
@@ -1716,29 +1728,12 @@ class SparkContext(config: SparkConf) extends Logging {
key = uri.getScheme match {
// A JAR file which exists only on the driver node
case null | "file" =>
- if (master == "yarn" && deployMode == "cluster") {
- // In order for this to work in yarn cluster mode the user must specify the
- // --addJars option to the client to upload the file into the distributed cache
- // of the AM to make it show up in the current working directory.
- val fileName = new Path(uri.getPath).getName()
- try {
- env.rpcEnv.fileServer.addJar(new File(fileName))
- } catch {
- case e: Exception =>
- // For now just log an error but allow to go through so spark examples work.
- // The spark examples don't really need the jar distributed since its also
- // the app jar.
- logError("Error adding jar (" + e + "), was the --addJars option used?")
- null
- }
- } else {
- try {
- env.rpcEnv.fileServer.addJar(new File(uri.getPath))
- } catch {
- case exc: FileNotFoundException =>
- logError(s"Jar not found at $path")
- null
- }
+ try {
+ env.rpcEnv.fileServer.addJar(new File(uri.getPath))
+ } catch {
+ case exc: FileNotFoundException =>
+ logError(s"Jar not found at $path")
+ null
}
// A JAR file which exists locally on every worker node
case "local" =>
@@ -1762,8 +1757,26 @@ class SparkContext(config: SparkConf) extends Logging {
*/
def listJars(): Seq[String] = addedJars.keySet.toSeq
- // Shut down the SparkContext.
- def stop() {
+ /**
+ * Shut down the SparkContext.
+ */
+ def stop(): Unit = {
+ if (env.rpcEnv.isInRPCThread) {
+ // `stop` will block until all RPC threads exit, so we cannot call stop inside an RPC thread.
+ // We should launch a new thread to call `stop` to avoid a deadlock.
+ new Thread("stop-spark-context") {
+ setDaemon(true)
+
+ override def run(): Unit = {
+ _stop()
+ }
+ }.start()
+ } else {
+ _stop()
+ }
+ }
+
+ private def _stop() {
if (LiveListenerBus.withinListenerThread.value) {
throw new SparkException(
s"Cannot stop SparkContext within listener thread of ${LiveListenerBus.name}")
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 6550d703bc86..46e22b215b8e 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -20,14 +20,14 @@ package org.apache.spark
import java.io.IOException
import java.text.NumberFormat
import java.text.SimpleDateFormat
-import java.util.Date
+import java.util.{Date, Locale}
import org.apache.hadoop.fs.FileSystem
-import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapred._
import org.apache.hadoop.mapreduce.TaskType
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.io.SparkHadoopWriterUtils
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.rdd.HadoopRDD
import org.apache.spark.util.SerializableJobConf
@@ -67,12 +67,12 @@ class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable {
def setup(jobid: Int, splitid: Int, attemptid: Int) {
setIDs(jobid, splitid, attemptid)
- HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss").format(now),
+ HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(now),
jobid, splitID, attemptID, conf.value)
}
def open() {
- val numfmt = NumberFormat.getInstance()
+ val numfmt = NumberFormat.getInstance(Locale.US)
numfmt.setMinimumIntegerDigits(5)
numfmt.setGroupingUsed(false)
@@ -153,29 +153,8 @@ class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable {
splitID = splitid
attemptID = attemptid
- jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobid))
+ jID = new SerializableWritable[JobID](SparkHadoopWriterUtils.createJobID(now, jobid))
taID = new SerializableWritable[TaskAttemptID](
new TaskAttemptID(new TaskID(jID.value, TaskType.MAP, splitID), attemptID))
}
}
-
-private[spark]
-object SparkHadoopWriter {
- def createJobID(time: Date, id: Int): JobID = {
- val formatter = new SimpleDateFormat("yyyyMMddHHmmss")
- val jobtrackerID = formatter.format(time)
- new JobID(jobtrackerID, id)
- }
-
- def createPathFromString(path: String, conf: JobConf): Path = {
- if (path == null) {
- throw new IllegalArgumentException("Output path is null")
- }
- val outputPath = new Path(path)
- val fs = outputPath.getFileSystem(conf)
- if (fs == null) {
- throw new IllegalArgumentException("Incorrectly formatted output path")
- }
- outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
index 41d0a85ee3ad..550746c552d0 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -22,12 +22,13 @@ import java.net.{InetAddress, InetSocketAddress, ServerSocket}
import java.util.concurrent.TimeUnit
import io.netty.bootstrap.ServerBootstrap
-import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup}
+import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption, EventLoopGroup}
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.codec.LengthFieldBasedFrameDecoder
import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder}
+import io.netty.handler.timeout.ReadTimeoutHandler
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
@@ -43,7 +44,10 @@ private[spark] class RBackend {
def init(): Int = {
val conf = new SparkConf()
- bossGroup = new NioEventLoopGroup(conf.getInt("spark.r.numRBackendThreads", 2))
+ val backendConnectionTimeout = conf.getInt(
+ "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
+ bossGroup = new NioEventLoopGroup(
+ conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS))
val workerGroup = bossGroup
val handler = new RBackendHandler(this)
@@ -63,6 +67,7 @@ private[spark] class RBackend {
// initialBytesToStrip = 4, i.e. strip out the length field itself
new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
.addLast("decoder", new ByteArrayDecoder())
+ .addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout))
.addLast("handler", handler)
}
})
@@ -110,6 +115,11 @@ private[spark] object RBackend extends Logging {
val boundPort = sparkRBackend.init()
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
val listenPort = serverSocket.getLocalPort()
+ // The connection timeout is set by the socket client. To make it configurable we pass the
+ // timeout value to the client inside the temp file.
+ val conf = new SparkConf()
+ val backendConnectionTimeout = conf.getInt(
+ "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
// tell the R process via temporary file
val path = args(0)
@@ -118,6 +128,7 @@ private[spark] object RBackend extends Logging {
dos.writeInt(boundPort)
dos.writeInt(listenPort)
SerDe.writeString(dos, RUtils.rPackages.getOrElse(""))
+ dos.writeInt(backendConnectionTimeout)
dos.close()
f.renameTo(new File(path))
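For context, the temp-file handshake written above now carries four fields that the R side reads back in order: the backend port, a second listener port, the SparkR package path, and the new connection timeout. A rough sketch of that write, with writeUTF used only as a stand-in for the length-prefixed string that SerDe.writeString produces:

    import java.io.{DataOutputStream, File, FileOutputStream}

    // Sketch of the handshake fields, in the order the R process reads them back.
    def writeHandshake(f: File, boundPort: Int, listenPort: Int,
                       rPackages: String, connectionTimeout: Int): Unit = {
      val dos = new DataOutputStream(new FileOutputStream(f))
      dos.writeInt(boundPort)
      dos.writeInt(listenPort)
      dos.writeUTF(rPackages)        // the patch uses SerDe.writeString here
      dos.writeInt(connectionTimeout)
      dos.close()
    }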
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index 1422ef888fd4..9f5afa29d6d2 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -18,16 +18,19 @@
package org.apache.spark.api.r
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import java.util.concurrent.TimeUnit
import scala.collection.mutable.HashMap
import scala.language.existentials
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import io.netty.channel.ChannelHandler.Sharable
+import io.netty.handler.timeout.ReadTimeoutException
import org.apache.spark.api.r.SerDe._
import org.apache.spark.internal.Logging
-import org.apache.spark.util.Utils
+import org.apache.spark.SparkConf
+import org.apache.spark.util.{ThreadUtils, Utils}
/**
* Handler for RBackend
@@ -83,7 +86,29 @@ private[r] class RBackendHandler(server: RBackend)
writeString(dos, s"Error: unknown method $methodName")
}
} else {
+ // To avoid timeouts while the SparkR driver is blocked reading results, we regularly send
+ // heartbeat responses. The special code +1 signals the client that the backend is
+ // still alive and that it should keep blocking for the result.
+ val execService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("SparkRKeepAliveThread")
+ val pingRunner = new Runnable {
+ override def run(): Unit = {
+ val pingBaos = new ByteArrayOutputStream()
+ val pingDaos = new DataOutputStream(pingBaos)
+ writeInt(pingDaos, +1)
+ ctx.write(pingBaos.toByteArray)
+ }
+ }
+ val conf = new SparkConf()
+ val heartBeatInterval = conf.getInt(
+ "spark.r.heartBeatInterval", SparkRDefaults.DEFAULT_HEARTBEAT_INTERVAL)
+ val backendConnectionTimeout = conf.getInt(
+ "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
+ val interval = Math.min(heartBeatInterval, backendConnectionTimeout - 1)
+
+ execService.scheduleAtFixedRate(pingRunner, interval, interval, TimeUnit.SECONDS)
handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
+ execService.shutdown()
+ execService.awaitTermination(1, TimeUnit.SECONDS)
}
val reply = bos.toByteArray
@@ -95,9 +120,15 @@ private[r] class RBackendHandler(server: RBackend)
}
override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
- // Close the connection when an exception is raised.
- cause.printStackTrace()
- ctx.close()
+ cause match {
+ case timeout: ReadTimeoutException =>
+ // Do nothing. We don't want to time out on reads.
+ logWarning("Ignoring read timeout in RBackendHandler")
+ case _ =>
+ // Close the connection when an exception is raised.
+ cause.printStackTrace()
+ ctx.close()
+ }
}
def handleMethodCall(
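The heartbeat logic above boils down to: schedule a periodic writer of the reserved +1 code while the method call runs, then shut the scheduler down so the real result is the last thing sent. A self-contained sketch of that scheduling pattern using only java.util.concurrent (illustrative, not the Spark classes):

    import java.util.concurrent.{Executors, TimeUnit}

    // Run `work` while sending a keep-alive token every `intervalSecs`, then stop pinging.
    def withKeepAlive[T](intervalSecs: Long)(sendPing: () => Unit)(work: => T): T = {
      val exec = Executors.newSingleThreadScheduledExecutor()
      exec.scheduleAtFixedRate(new Runnable {
        override def run(): Unit = sendPing()
      }, intervalSecs, intervalSecs, TimeUnit.SECONDS)
      try {
        work
      } finally {
        exec.shutdown()
        exec.awaitTermination(1, TimeUnit.SECONDS)
      }
    }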
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
index 496fdf851f7d..7ef64723d959 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -333,6 +333,8 @@ private[r] object RRunner {
var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript")
rCommand = sparkConf.get("spark.r.command", rCommand)
+ val rConnectionTimeout = sparkConf.getInt(
+ "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
val rOptions = "--vanilla"
val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
val rExecScript = rLibDir(0) + "/SparkR/worker/" + script
@@ -344,6 +346,7 @@ private[r] object RRunner {
pb.environment().put("R_TESTS", "")
pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(","))
pb.environment().put("SPARKR_WORKER_PORT", port.toString)
+ pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString)
pb.redirectErrorStream(true) // redirect stderr into stdout
val proc = pb.start()
val errThread = startStdoutThread(proc)
diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
index 77825e75e513..fdd8cf62f0e5 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
@@ -84,7 +84,6 @@ private[spark] object RUtils {
}
} else {
// Otherwise, assume the package is local
- // TODO: support this for Mesos
val sparkRPkgPath = localSparkRPackagePath.getOrElse {
throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.")
}
diff --git a/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala b/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala
new file mode 100644
index 000000000000..af67cbbce4e5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+private[spark] object SparkRDefaults {
+
+ // Default value for spark.r.backendConnectionTimeout config
+ val DEFAULT_CONNECTION_TIMEOUT: Int = 6000
+
+ // Default value for spark.r.heartBeatInterval config
+ val DEFAULT_HEARTBEAT_INTERVAL: Int = 100
+
+ // Default value for spark.r.numRBackendThreads config
+ val DEFAULT_NUM_RBACKEND_THREADS = 2
+}
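For reference, these defaults are consumed through plain SparkConf lookups, so each can be overridden with the matching spark.r.* setting (for example via --conf on spark-submit). A condensed sketch of the lookup pattern the patch uses; note that SparkRDefaults is private[spark], so this mirrors in-tree usage only:

    import org.apache.spark.SparkConf
    import org.apache.spark.api.r.SparkRDefaults

    // Falls back to the defaults above when the spark.r.* settings are absent.
    val conf = new SparkConf()
    val connectionTimeout = conf.getInt(
      "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
    val heartBeatInterval = conf.getInt(
      "spark.r.heartBeatInterval", SparkRDefaults.DEFAULT_HEARTBEAT_INTERVAL)
    val numRBackendThreads = conf.getInt(
      "spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS)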
diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
index d0466830b217..6eb53a825220 100644
--- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
@@ -25,7 +25,7 @@ import scala.collection.JavaConverters._
import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkException, SparkUserAppException}
-import org.apache.spark.api.r.{RBackend, RUtils}
+import org.apache.spark.api.r.{RBackend, RUtils, SparkRDefaults}
import org.apache.spark.util.RedirectThread
/**
@@ -51,6 +51,10 @@ object RRunner {
cmd
}
+ // Connection timeout, in seconds, set by the R process on its connection to the RBackend.
+ val backendConnectionTimeout = sys.props.getOrElse(
+ "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT.toString)
+
// Check if the file path exists.
// If not, change directory to current working directory for YARN cluster mode
val rF = new File(rFile)
@@ -81,6 +85,7 @@ object RRunner {
val builder = new ProcessBuilder((Seq(rCommand, rFileNormalized) ++ otherArgs).asJava)
val env = builder.environment()
env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString)
+ env.put("SPARKR_BACKEND_CONNECTION_TIMEOUT", backendConnectionTimeout)
val rPackageDir = RUtils.sparkRPackagePath(isDriver = true)
// Put the R package directories into an env variable of comma-separated paths
env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(","))
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 3f54ecc17ac3..23156072c3eb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -21,7 +21,7 @@ import java.io.IOException
import java.lang.reflect.Method
import java.security.PrivilegedExceptionAction
import java.text.DateFormat
-import java.util.{Arrays, Comparator, Date}
+import java.util.{Arrays, Comparator, Date, Locale}
import scala.collection.JavaConverters._
import scala.util.control.NonFatal
@@ -357,7 +357,7 @@ class SparkHadoopUtil extends Logging {
* @return a printable string value.
*/
private[spark] def tokenToString(token: Token[_ <: TokenIdentifier]): String = {
- val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT)
+ val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT, Locale.US)
val buffer = new StringBuilder(128)
buffer.append(token.toString)
try {
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 5c052286099f..c70061bc5b5b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -322,7 +322,7 @@ object SparkSubmit {
}
// Require all R files to be local
- if (args.isR && !isYarnCluster) {
+ if (args.isR && !isYarnCluster && !isMesosCluster) {
if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) {
printErrorAndExit(s"Only local R files are supported: ${args.primaryResource}")
}
@@ -330,9 +330,6 @@ object SparkSubmit {
// The following modes are not supported or applicable
(clusterManager, deployMode) match {
- case (MESOS, CLUSTER) if args.isR =>
- printErrorAndExit("Cluster deploy mode is currently not supported for R " +
- "applications on Mesos clusters.")
case (STANDALONE, CLUSTER) if args.isPython =>
printErrorAndExit("Cluster deploy mode is currently not supported for python " +
"applications on standalone clusters.")
@@ -410,9 +407,9 @@ object SparkSubmit {
printErrorAndExit("Distributing R packages with standalone cluster is not supported.")
}
- // TODO: Support SparkR with mesos cluster
- if (args.isR && clusterManager == MESOS) {
- printErrorAndExit("SparkR is not supported for Mesos cluster.")
+ // TODO: Support distributing R packages with mesos cluster
+ if (args.isR && clusterManager == MESOS && !RUtils.rPackages.isEmpty) {
+ printErrorAndExit("Distributing R packages with mesos cluster is not supported.")
}
// If we're running an R app, set the main class to our specific R runner
@@ -598,6 +595,9 @@ object SparkSubmit {
if (args.pyFiles != null) {
sysProps("spark.submit.pyFiles") = args.pyFiles
}
+ } else if (args.isR) {
+ // Second argument is main class
+ childArgs += (args.primaryResource, "")
} else {
childArgs += (args.primaryResource, args.mainClass)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 8c91aa15167c..4618e6117a4f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -18,7 +18,7 @@
package org.apache.spark.deploy.master
import java.text.SimpleDateFormat
-import java.util.Date
+import java.util.{Date, Locale}
import java.util.concurrent.{ScheduledFuture, TimeUnit}
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
@@ -51,7 +51,8 @@ private[deploy] class Master(
private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
- private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
+ // For application IDs
+ private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US)
private val WORKER_TIMEOUT_MS = conf.getLong("spark.worker.timeout", 60) * 1000
private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 0bedd9a20a96..8b1c6bf2e5fd 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -20,7 +20,7 @@ package org.apache.spark.deploy.worker
import java.io.File
import java.io.IOException
import java.text.SimpleDateFormat
-import java.util.{Date, UUID}
+import java.util.{Date, Locale, UUID}
import java.util.concurrent._
import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture}
@@ -68,7 +68,7 @@ private[deploy] class Worker(
ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread"))
// For worker and executor IDs
- private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
+ private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US)
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4
diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
index f66510b6f977..59404e08895a 100644
--- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
+++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
@@ -27,6 +27,9 @@ import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit}
+import org.apache.spark.internal.config
+import org.apache.spark.SparkContext
+
/**
* A general format for reading whole files in as streams, byte arrays,
* or other functions to be added
@@ -40,9 +43,14 @@ private[spark] abstract class StreamFileInputFormat[T]
* Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API
* which is set through setMaxSplitSize
*/
- def setMinPartitions(context: JobContext, minPartitions: Int) {
- val totalLen = listStatus(context).asScala.filterNot(_.isDirectory).map(_.getLen).sum
- val maxSplitSize = math.ceil(totalLen / math.max(minPartitions, 1.0)).toLong
+ def setMinPartitions(sc: SparkContext, context: JobContext, minPartitions: Int) {
+ val defaultMaxSplitBytes = sc.getConf.get(config.FILES_MAX_PARTITION_BYTES)
+ val openCostInBytes = sc.getConf.get(config.FILES_OPEN_COST_IN_BYTES)
+ val defaultParallelism = sc.defaultParallelism
+ val files = listStatus(context).asScala
+ val totalBytes = files.filterNot(_.isDirectory).map(_.getLen + openCostInBytes).sum
+ val bytesPerCore = totalBytes / defaultParallelism
+ val maxSplitSize = Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore))
super.setMaxSplitSize(maxSplitSize)
}
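To make the new sizing concrete with illustrative numbers: with the default spark.files.maxPartitionBytes of 128 MB, spark.files.openCostInBytes of 4 MB, a default parallelism of 8, and ten 1 MB files, totalBytes = 10 * (1 MB + 4 MB) = 50 MB, bytesPerCore = 6.25 MB, and maxSplitSize = min(128 MB, max(4 MB, 6.25 MB)) = 6.25 MB, so small files get packed into a few partitions instead of one partition per file. A standalone sketch of the arithmetic:

    // Sketch of the split-size formula used in setMinPartitions above (no Hadoop types).
    def maxSplitSize(fileSizes: Seq[Long],
                     defaultMaxSplitBytes: Long = 128L * 1024 * 1024,
                     openCostInBytes: Long = 4L * 1024 * 1024,
                     defaultParallelism: Int = 8): Long = {
      val totalBytes = fileSizes.map(_ + openCostInBytes).sum
      val bytesPerCore = totalBytes / defaultParallelism
      math.min(defaultMaxSplitBytes, math.max(openCostInBytes, bytesPerCore))
    }

    // maxSplitSize(Seq.fill(10)(1024L * 1024))  // ten 1 MB files -> 6553600 bytes (~6.25 MB)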
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 497ca92c7bc6..4a3e3d5c79ef 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -206,4 +206,17 @@ package object config {
"encountering corrupt files and contents that have been read will still be returned.")
.booleanConf
.createWithDefault(false)
+
+ private[spark] val FILES_MAX_PARTITION_BYTES = ConfigBuilder("spark.files.maxPartitionBytes")
+ .doc("The maximum number of bytes to pack into a single partition when reading files.")
+ .longConf
+ .createWithDefault(128 * 1024 * 1024)
+
+ private[spark] val FILES_OPEN_COST_IN_BYTES = ConfigBuilder("spark.files.openCostInBytes")
+ .doc("The estimated cost to open a file, measured by the number of bytes could be scanned in" +
+ " the same time. This is used when putting multiple files into a partition. It's better to" +
+ " over estimate, then the partitions with small files will be faster than partitions with" +
+ " bigger files.")
+ .longConf
+ .createWithDefault(4 * 1024 * 1024)
}
diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
new file mode 100644
index 000000000000..fb8020585cf8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.internal.io
+
+import org.apache.hadoop.mapreduce._
+
+import org.apache.spark.util.Utils
+
+
+/**
+ * An interface to define how a single Spark job commits its outputs. Three notes:
+ *
+ * 1. Implementations must be serializable, as the committer instance instantiated on the driver
+ * will be used for tasks on executors.
+ * 2. Implementations should have a constructor with either 2 or 3 arguments:
+ * (jobId: String, path: String) or (jobId: String, path: String, isAppend: Boolean).
+ * 3. A committer should not be reused across multiple Spark jobs.
+ *
+ * The proper call sequence is:
+ *
+ * 1. Driver calls setupJob.
+ * 2. As part of each task's execution, executor calls setupTask and then commitTask
+ * (or abortTask if task failed).
+ * 3. When all necessary tasks have completed successfully, the driver calls commitJob. If the
+ * job failed to execute (e.g. too many failed tasks), the driver should call abortJob.
+ */
+abstract class FileCommitProtocol {
+ import FileCommitProtocol._
+
+ /**
+ * Sets up a job. Must be called on the driver before any other methods can be invoked.
+ */
+ def setupJob(jobContext: JobContext): Unit
+
+ /**
+ * Commits a job after the writes succeed. Must be called on the driver.
+ */
+ def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit
+
+ /**
+ * Aborts a job after the writes fail. Must be called on the driver.
+ *
+ * Calling this function is a best-effort attempt, because it is possible that the driver
+ * just crashes (or is killed) before it can call abort.
+ */
+ def abortJob(jobContext: JobContext): Unit
+
+ /**
+ * Sets up a task within a job.
+ * Must be called before any other task related methods can be invoked.
+ */
+ def setupTask(taskContext: TaskAttemptContext): Unit
+
+ /**
+ * Notifies the commit protocol to add a new file, and gets back the full path that should be
+ * used. Must be called on the executors when running tasks.
+ *
+ * Note that the returned temp file may have an arbitrary path. The commit protocol only
+ * promises that the file will be at the location specified by the arguments after job commit.
+ *
+ * A full file path consists of the following parts:
+ * 1. the base path
+ * 2. some sub-directory within the base path, used to specify partitioning
+ * 3. file prefix, usually some unique job id with the task id
+ * 4. bucket id
+ * 5. source specific file extension, e.g. ".snappy.parquet"
+ *
+ * The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest
+ * are left to the commit protocol implementation to decide.
+ */
+ def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String
+
+ /**
+ * Commits a task after the writes succeed. Must be called on the executors when running tasks.
+ */
+ def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage
+
+ /**
+ * Aborts a task after the writes have failed. Must be called on the executors when running tasks.
+ *
+ * Calling this function is a best-effort attempt, because it is possible that the executor
+ * just crashes (or is killed) before it can call abort.
+ */
+ def abortTask(taskContext: TaskAttemptContext): Unit
+}
+
+
+object FileCommitProtocol {
+ class TaskCommitMessage(val obj: Any) extends Serializable
+
+ object EmptyTaskCommitMessage extends TaskCommitMessage(null)
+
+ /**
+ * Instantiates a FileCommitProtocol using the given className.
+ */
+ def instantiate(className: String, jobId: String, outputPath: String, isAppend: Boolean)
+ : FileCommitProtocol = {
+ val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]]
+
+ // First try the one with argument (jobId: String, outputPath: String, isAppend: Boolean).
+ // If that doesn't exist, try the one with (jobId: string, outputPath: String).
+ try {
+ val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean])
+ ctor.newInstance(jobId, outputPath, isAppend.asInstanceOf[java.lang.Boolean])
+ } catch {
+ case _: NoSuchMethodException =>
+ val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String])
+ ctor.newInstance(jobId, outputPath)
+ }
+ }
+}
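To make the documented call sequence concrete, here is a hedged sketch of driving a FileCommitProtocol end to end; in a real job the per-task calls run on executors rather than in a local loop, and the contexts come from Hadoop, so the names and wiring here are illustrative only:

    import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
    import org.apache.spark.internal.io.FileCommitProtocol
    import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage

    // Sketch of the setupJob -> (setupTask/commitTask)* -> commitJob sequence.
    def runWithCommitProtocol(
        committer: FileCommitProtocol,
        jobContext: JobContext,
        taskContexts: Seq[TaskAttemptContext]): Unit = {
      committer.setupJob(jobContext)                       // 1. driver
      try {
        val commits: Seq[TaskCommitMessage] = taskContexts.map { tc =>
          committer.setupTask(tc)                          // 2. per task
          try {
            val path = committer.newTaskTempFile(tc, dir = None, ext = ".snappy.parquet")
            // ... the task writes its rows to `path` ...
            committer.commitTask(tc)
          } catch {
            case e: Throwable => committer.abortTask(tc); throw e
          }
        }
        committer.commitJob(jobContext, commits)           // 3. driver, after all tasks commit
      } catch {
        case e: Throwable => committer.abortJob(jobContext); throw e
      }
    }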
diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
new file mode 100644
index 000000000000..6b0bcb8f908b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.internal.io
+
+import java.util.Date
+
+import org.apache.hadoop.conf.Configurable
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.mapred.SparkHadoopMapRedUtil
+
+/**
+ * A [[FileCommitProtocol]] implementation backed by an underlying Hadoop OutputCommitter
+ * (from the newer mapreduce API, not the old mapred API).
+ *
+ * Unlike Hadoop's OutputCommitter, this implementation is serializable.
+ */
+class HadoopMapReduceCommitProtocol(jobId: String, path: String)
+ extends FileCommitProtocol with Serializable with Logging {
+
+ import FileCommitProtocol._
+
+ /** OutputCommitter from Hadoop is not serializable so marking it transient. */
+ @transient private var committer: OutputCommitter = _
+
+ protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = {
+ val format = context.getOutputFormatClass.newInstance()
+ // If OutputFormat is Configurable, we should set conf to it.
+ format match {
+ case c: Configurable => c.setConf(context.getConfiguration)
+ case _ => ()
+ }
+ format.getOutputCommitter(context)
+ }
+
+ override def newTaskTempFile(
+ taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = {
+ // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet
+ // Note that %05d does not truncate the split number, so if we have more than 100000 tasks,
+ // the file name is fine and won't overflow.
+ val split = taskContext.getTaskAttemptID.getTaskID.getId
+ val filename = f"part-$split%05d-$jobId$ext"
+
+ val stagingDir: String = committer match {
+ // For FileOutputCommitter it has its own staging path called "work path".
+ case f: FileOutputCommitter => Option(f.getWorkPath.toString).getOrElse(path)
+ case _ => path
+ }
+
+ dir.map { d =>
+ new Path(new Path(stagingDir, d), filename).toString
+ }.getOrElse {
+ new Path(stagingDir, filename).toString
+ }
+ }
+
+ override def setupJob(jobContext: JobContext): Unit = {
+ // Setup IDs
+ val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0)
+ val taskId = new TaskID(jobId, TaskType.MAP, 0)
+ val taskAttemptId = new TaskAttemptID(taskId, 0)
+
+ // Set up the configuration object
+ jobContext.getConfiguration.set("mapred.job.id", jobId.toString)
+ jobContext.getConfiguration.set("mapred.tip.id", taskAttemptId.getTaskID.toString)
+ jobContext.getConfiguration.set("mapred.task.id", taskAttemptId.toString)
+ jobContext.getConfiguration.setBoolean("mapred.task.is.map", true)
+ jobContext.getConfiguration.setInt("mapred.task.partition", 0)
+
+ val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId)
+ committer = setupCommitter(taskAttemptContext)
+ committer.setupJob(jobContext)
+ }
+
+ override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = {
+ committer.commitJob(jobContext)
+ }
+
+ override def abortJob(jobContext: JobContext): Unit = {
+ committer.abortJob(jobContext, JobStatus.State.FAILED)
+ }
+
+ override def setupTask(taskContext: TaskAttemptContext): Unit = {
+ committer = setupCommitter(taskContext)
+ committer.setupTask(taskContext)
+ }
+
+ override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = {
+ val attemptId = taskContext.getTaskAttemptID
+ SparkHadoopMapRedUtil.commitTask(
+ committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId)
+ EmptyTaskCommitMessage
+ }
+
+ override def abortTask(taskContext: TaskAttemptContext): Unit = {
+ committer.abortTask(taskContext)
+ }
+
+ /** Whether we are using a direct output committer */
+ def isDirectOutput(): Boolean = committer.getClass.getSimpleName.contains("Direct")
+}
diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala
new file mode 100644
index 000000000000..796439276a22
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.internal.io
+
+import java.text.SimpleDateFormat
+import java.util.{Date, Locale}
+
+import scala.reflect.ClassTag
+import scala.util.DynamicVariable
+
+import org.apache.hadoop.conf.{Configurable, Configuration}
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapred.{JobConf, JobID}
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+
+import org.apache.spark.{SparkConf, SparkException, TaskContext}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.executor.OutputMetrics
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.{SerializableConfiguration, Utils}
+
+/**
+ * A helper object that saves an RDD using a Hadoop OutputFormat
+ * (from the newer mapreduce API, not the old mapred API).
+ */
+private[spark]
+object SparkHadoopMapReduceWriter extends Logging {
+
+ /**
+ * The basic workflow of this writer is:
+ * 1. Driver-side setup: prepare the data source and Hadoop configuration for the write job to
+ * be issued.
+ * 2. Issue a write job consisting of one or more executor-side tasks, each of which writes all
+ * rows within an RDD partition.
+ * 3. If no exception is thrown in a task, commit that task; otherwise abort it. If any
+ * exception is thrown during task commitment, also abort that task.
+ * 4. If all tasks are committed, commit the job; otherwise abort the job. If any exception is
+ * thrown during job commitment, also abort the job.
+ */
+ def write[K, V: ClassTag](
+ rdd: RDD[(K, V)],
+ hadoopConf: Configuration): Unit = {
+ // Extract context and configuration from RDD.
+ val sparkContext = rdd.context
+ val stageId = rdd.id
+ val sparkConf = rdd.conf
+ val conf = new SerializableConfiguration(hadoopConf)
+
+ // Set up a job.
+ val jobTrackerId = SparkHadoopWriterUtils.createJobTrackerID(new Date())
+ val jobAttemptId = new TaskAttemptID(jobTrackerId, stageId, TaskType.MAP, 0, 0)
+ val jobContext = new TaskAttemptContextImpl(conf.value, jobAttemptId)
+ val format = jobContext.getOutputFormatClass
+
+ if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(sparkConf)) {
+ // FileOutputFormat ignores the filesystem parameter
+ val jobFormat = format.newInstance
+ jobFormat.checkOutputSpecs(jobContext)
+ }
+
+ val committer = FileCommitProtocol.instantiate(
+ className = classOf[HadoopMapReduceCommitProtocol].getName,
+ jobId = stageId.toString,
+ outputPath = conf.value.get("mapred.output.dir"),
+ isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol]
+ committer.setupJob(jobContext)
+
+ // When speculation is on and output committer class name contains "Direct", we should warn
+ // users that they may lose data if they are using a direct output committer.
+ if (SparkHadoopWriterUtils.isSpeculationEnabled(sparkConf) && committer.isDirectOutput) {
+ val warningMessage =
+ s"$committer may be an output committer that writes data directly to " +
+ "the final location. Because speculation is enabled, this output committer may " +
+ "cause data loss (see the case in SPARK-10063). If possible, please use an output " +
+ "committer that does not have this behavior (e.g. FileOutputCommitter)."
+ logWarning(warningMessage)
+ }
+
+ // Try to write all RDD partitions as a Hadoop OutputFormat.
+ try {
+ val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => {
+ executeTask(
+ context = context,
+ jobTrackerId = jobTrackerId,
+ sparkStageId = context.stageId,
+ sparkPartitionId = context.partitionId,
+ sparkAttemptNumber = context.attemptNumber,
+ committer = committer,
+ hadoopConf = conf.value,
+ outputFormat = format.asInstanceOf[Class[OutputFormat[K, V]]],
+ iterator = iter)
+ })
+
+ committer.commitJob(jobContext, ret)
+ logInfo(s"Job ${jobContext.getJobID} committed.")
+ } catch {
+ case cause: Throwable =>
+ logError(s"Aborting job ${jobContext.getJobID}.", cause)
+ committer.abortJob(jobContext)
+ throw new SparkException("Job aborted.", cause)
+ }
+ }
+
+ /** Write an RDD partition out in a single Spark task. */
+ private def executeTask[K, V: ClassTag](
+ context: TaskContext,
+ jobTrackerId: String,
+ sparkStageId: Int,
+ sparkPartitionId: Int,
+ sparkAttemptNumber: Int,
+ committer: FileCommitProtocol,
+ hadoopConf: Configuration,
+ outputFormat: Class[_ <: OutputFormat[K, V]],
+ iterator: Iterator[(K, V)]): TaskCommitMessage = {
+ // Set up a task.
+ val attemptId = new TaskAttemptID(jobTrackerId, sparkStageId, TaskType.REDUCE,
+ sparkPartitionId, sparkAttemptNumber)
+ val taskContext = new TaskAttemptContextImpl(hadoopConf, attemptId)
+ committer.setupTask(taskContext)
+
+ val outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)] =
+ SparkHadoopWriterUtils.initHadoopOutputMetrics(context)
+
+ // Initiate the writer.
+ val taskFormat = outputFormat.newInstance()
+ // If OutputFormat is Configurable, we should set conf to it.
+ taskFormat match {
+ case c: Configurable => c.setConf(hadoopConf)
+ case _ => ()
+ }
+ val writer = taskFormat.getRecordWriter(taskContext)
+ .asInstanceOf[RecordWriter[K, V]]
+ require(writer != null, "Unable to obtain RecordWriter")
+ var recordsWritten = 0L
+
+ // Write all rows in RDD partition.
+ try {
+ val ret = Utils.tryWithSafeFinallyAndFailureCallbacks {
+ while (iterator.hasNext) {
+ val pair = iterator.next()
+ writer.write(pair._1, pair._2)
+
+ // Update bytes written metric every few records
+ SparkHadoopWriterUtils.maybeUpdateOutputMetrics(
+ outputMetricsAndBytesWrittenCallback, recordsWritten)
+ recordsWritten += 1
+ }
+
+ committer.commitTask(taskContext)
+ }(catchBlock = {
+ committer.abortTask(taskContext)
+ logError(s"Task ${taskContext.getTaskAttemptID} aborted.")
+ }, finallyBlock = writer.close(taskContext))
+
+ outputMetricsAndBytesWrittenCallback.foreach {
+ case (om, callback) =>
+ om.setBytesWritten(callback())
+ om.setRecordsWritten(recordsWritten)
+ }
+
+ ret
+ } catch {
+ case t: Throwable =>
+ throw new SparkException("Task failed while writing rows", t)
+ }
+ }
+}
+
+private[spark]
+object SparkHadoopWriterUtils {
+
+ private val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256
+
+ def createJobID(time: Date, id: Int): JobID = {
+ val jobtrackerID = createJobTrackerID(time)
+ new JobID(jobtrackerID, id)
+ }
+
+ def createJobTrackerID(time: Date): String = {
+ new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(time)
+ }
+
+ def createPathFromString(path: String, conf: JobConf): Path = {
+ if (path == null) {
+ throw new IllegalArgumentException("Output path is null")
+ }
+ val outputPath = new Path(path)
+ val fs = outputPath.getFileSystem(conf)
+ if (fs == null) {
+ throw new IllegalArgumentException("Incorrectly formatted output path")
+ }
+ outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
+ }
+
+ // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation
+ // setting can take effect:
+ def isOutputSpecValidationEnabled(conf: SparkConf): Boolean = {
+ val validationDisabled = disableOutputSpecValidation.value
+ val enabledInConf = conf.getBoolean("spark.hadoop.validateOutputSpecs", true)
+ enabledInConf && !validationDisabled
+ }
+
+ def isSpeculationEnabled(conf: SparkConf): Boolean = {
+ conf.getBoolean("spark.speculation", false)
+ }
+
+ // TODO: these don't seem like the right abstractions.
+ // We should abstract the duplicate code in a less awkward way.
+
+ // return type: (output metrics, bytes written callback), defined only if the latter is defined
+ def initHadoopOutputMetrics(
+ context: TaskContext): Option[(OutputMetrics, () => Long)] = {
+ val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback()
+ bytesWrittenCallback.map { b =>
+ (context.taskMetrics().outputMetrics, b)
+ }
+ }
+
+ def maybeUpdateOutputMetrics(
+ outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)],
+ recordsWritten: Long): Unit = {
+ if (recordsWritten % RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) {
+ outputMetricsAndBytesWrittenCallback.foreach {
+ case (om, callback) =>
+ om.setBytesWritten(callback())
+ om.setRecordsWritten(recordsWritten)
+ }
+ }
+ }
+
+ /**
+ * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case
+ * basis; see SPARK-4835 for more details.
+ */
+ val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false)
+}
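From the user-facing API, the net effect (together with the PairRDDFunctions change below) is that saveAsNewAPIHadoopDataset now routes through this writer and the commit protocol. A hedged usage sketch; the master, output path, and format choice are illustrative:

    import org.apache.hadoop.fs.Path
    import org.apache.hadoop.io.{NullWritable, Text}
    import org.apache.hadoop.mapreduce.Job
    import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
    import org.apache.spark.{SparkConf, SparkContext}

    val sc = new SparkContext(new SparkConf().setAppName("writer-sketch").setMaster("local[2]"))
    val records = sc.parallelize(Seq("a" -> 1, "b" -> 2)).map { case (k, v) =>
      (NullWritable.get(), new Text(s"$k,$v"))
    }

    val job = Job.getInstance(sc.hadoopConfiguration)
    job.setOutputKeyClass(classOf[NullWritable])
    job.setOutputValueClass(classOf[Text])
    job.setOutputFormatClass(classOf[TextOutputFormat[NullWritable, Text]])
    FileOutputFormat.setOutputPath(job, new Path("/tmp/writer-sketch-output"))

    // Internally this call now delegates to SparkHadoopMapReduceWriter.write(rdd, conf).
    records.saveAsNewAPIHadoopDataset(job.getConfiguration)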
diff --git a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala
index b54885b7ff8b..3f7cfd9d2c11 100644
--- a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala
@@ -76,7 +76,7 @@ object HiveCatalogMetrics extends Source {
val METRIC_PARTITIONS_FETCHED = metricRegistry.counter(MetricRegistry.name("partitionsFetched"))
/**
- * Tracks the total number of files discovered off of the filesystem by ListingFileCatalog.
+ * Tracks the total number of files discovered off of the filesystem by InMemoryFileIndex.
*/
val METRIC_FILES_DISCOVERED = metricRegistry.counter(MetricRegistry.name("filesDiscovered"))
diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
index 41832e835474..50d977a92da5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
@@ -26,7 +26,7 @@ import org.apache.spark.{Partition, SparkContext}
import org.apache.spark.input.StreamFileInputFormat
private[spark] class BinaryFileRDD[T](
- sc: SparkContext,
+ @transient private val sc: SparkContext,
inputFormatClass: Class[_ <: StreamFileInputFormat[T]],
keyClass: Class[String],
valueClass: Class[T],
@@ -43,7 +43,7 @@ private[spark] class BinaryFileRDD[T](
case _ =>
}
val jobContext = new JobContextImpl(conf, jobId)
- inputFormat.setMinPartitions(jobContext, minPartitions)
+ inputFormat.setMinPartitions(sc, jobContext, minPartitions)
val rawSplits = inputFormat.getSplits(jobContext).toArray
val result = new Array[Partition](rawSplits.size)
for (i <- 0 until rawSplits.size) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index e1cf3938de09..36a2f5c87e37 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -19,7 +19,7 @@ package org.apache.spark.rdd
import java.io.IOException
import java.text.SimpleDateFormat
-import java.util.Date
+import java.util.{Date, Locale}
import scala.collection.immutable.Map
import scala.reflect.ClassTag
@@ -243,7 +243,8 @@ class HadoopRDD[K, V](
var reader: RecordReader[K, V] = null
val inputFormat = getInputFormat(jobConf)
- HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss").format(createTime),
+ HadoopRDD.addLocalConfiguration(
+ new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(createTime),
context.stageId, theSplit.index, context.attemptNumber, jobConf)
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index baf31fb65887..488e777fea37 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -19,7 +19,7 @@ package org.apache.spark.rdd
import java.io.IOException
import java.text.SimpleDateFormat
-import java.util.Date
+import java.util.{Date, Locale}
import scala.reflect.ClassTag
@@ -79,7 +79,7 @@ class NewHadoopRDD[K, V](
// private val serializableConf = new SerializableWritable(_conf)
private val jobTrackerId: String = {
- val formatter = new SimpleDateFormat("yyyyMMddHHmmss")
+ val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US)
formatter.format(new Date())
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 068f4ed8ad74..f9b9631d9e7c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -18,33 +18,31 @@
package org.apache.spark.rdd
import java.nio.ByteBuffer
-import java.text.SimpleDateFormat
-import java.util.{Date, HashMap => JHashMap}
+import java.util.{HashMap => JHashMap}
import scala.collection.{mutable, Map}
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
-import scala.util.DynamicVariable
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
-import org.apache.hadoop.conf.{Configurable, Configuration}
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat}
-import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptID, TaskType}
-import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat}
import org.apache.spark._
import org.apache.spark.Partitioner.defaultPartitioner
import org.apache.spark.annotation.Experimental
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.OutputMetrics
+import org.apache.spark.internal.io.{FileCommitProtocol, HadoopMapReduceCommitProtocol, SparkHadoopMapReduceWriter, SparkHadoopWriterUtils}
import org.apache.spark.internal.Logging
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.serializer.Serializer
-import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util.Utils
import org.apache.spark.util.collection.CompactBuffer
import org.apache.spark.util.random.StratifiedSamplingUtils
@@ -1060,7 +1058,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
}
FileOutputFormat.setOutputPath(hadoopConf,
- SparkHadoopWriter.createPathFromString(path, hadoopConf))
+ SparkHadoopWriterUtils.createPathFromString(path, hadoopConf))
saveAsHadoopDataset(hadoopConf)
}
@@ -1076,80 +1074,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* result of using direct output committer with speculation enabled.
*/
def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope {
- // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038).
- val hadoopConf = conf
- val job = NewAPIHadoopJob.getInstance(hadoopConf)
- val formatter = new SimpleDateFormat("yyyyMMddHHmmss")
- val jobtrackerID = formatter.format(new Date())
- val stageId = self.id
- val jobConfiguration = job.getConfiguration
- val wrappedConf = new SerializableConfiguration(jobConfiguration)
- val outfmt = job.getOutputFormatClass
- val jobFormat = outfmt.newInstance
-
- if (isOutputSpecValidationEnabled) {
- // FileOutputFormat ignores the filesystem parameter
- jobFormat.checkOutputSpecs(job)
- }
-
- val writeShard = (context: TaskContext, iter: Iterator[(K, V)]) => {
- val config = wrappedConf.value
- /* "reduce task" */
- val attemptId = new TaskAttemptID(jobtrackerID, stageId, TaskType.REDUCE, context.partitionId,
- context.attemptNumber)
- val hadoopContext = new TaskAttemptContextImpl(config, attemptId)
- val format = outfmt.newInstance
- format match {
- case c: Configurable => c.setConf(config)
- case _ => ()
- }
- val committer = format.getOutputCommitter(hadoopContext)
- committer.setupTask(hadoopContext)
-
- val outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)] =
- initHadoopOutputMetrics(context)
-
- val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]]
- require(writer != null, "Unable to obtain RecordWriter")
- var recordsWritten = 0L
- Utils.tryWithSafeFinallyAndFailureCallbacks {
- while (iter.hasNext) {
- val pair = iter.next()
- writer.write(pair._1, pair._2)
-
- // Update bytes written metric every few records
- maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten)
- recordsWritten += 1
- }
- }(finallyBlock = writer.close(hadoopContext))
- committer.commitTask(hadoopContext)
- outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) =>
- om.setBytesWritten(callback())
- om.setRecordsWritten(recordsWritten)
- }
- 1
- } : Int
-
- val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, TaskType.MAP, 0, 0)
- val jobTaskContext = new TaskAttemptContextImpl(wrappedConf.value, jobAttemptId)
- val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
-
- // When speculation is on and output committer class name contains "Direct", we should warn
- // users that they may loss data if they are using a direct output committer.
- val speculationEnabled = self.conf.getBoolean("spark.speculation", false)
- val outputCommitterClass = jobCommitter.getClass.getSimpleName
- if (speculationEnabled && outputCommitterClass.contains("Direct")) {
- val warningMessage =
- s"$outputCommitterClass may be an output committer that writes data directly to " +
- "the final location. Because speculation is enabled, this output committer may " +
- "cause data loss (see the case in SPARK-10063). If possible, please use an output " +
- "committer that does not have this behavior (e.g. FileOutputCommitter)."
- logWarning(warningMessage)
- }
-
- jobCommitter.setupJob(jobTaskContext)
- self.context.runJob(self, writeShard)
- jobCommitter.commitJob(jobTaskContext)
+ SparkHadoopMapReduceWriter.write(
+ rdd = self,
+ hadoopConf = conf)
}
/**
@@ -1178,7 +1105,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " +
valueClass.getSimpleName + ")")
- if (isOutputSpecValidationEnabled) {
+ if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(self.conf)) {
// FileOutputFormat ignores the filesystem parameter
val ignoredFs = FileSystem.get(hadoopConf)
hadoopConf.getOutputFormat.checkOutputSpecs(ignoredFs, hadoopConf)
@@ -1193,7 +1120,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt
val outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)] =
- initHadoopOutputMetrics(context)
+ SparkHadoopWriterUtils.initHadoopOutputMetrics(context)
writer.setup(context.stageId, context.partitionId, taskAttemptId)
writer.open()
@@ -1205,7 +1132,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
// Update bytes written metric every few records
- maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten)
+ SparkHadoopWriterUtils.maybeUpdateOutputMetrics(
+ outputMetricsAndBytesWrittenCallback, recordsWritten)
recordsWritten += 1
}
}(finallyBlock = writer.close())
@@ -1220,29 +1148,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
writer.commitJob()
}
- // TODO: these don't seem like the right abstractions.
- // We should abstract the duplicate code in a less awkward way.
-
- // return type: (output metrics, bytes written callback), defined only if the latter is defined
- private def initHadoopOutputMetrics(
- context: TaskContext): Option[(OutputMetrics, () => Long)] = {
- val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback()
- bytesWrittenCallback.map { b =>
- (context.taskMetrics().outputMetrics, b)
- }
- }
-
- private def maybeUpdateOutputMetrics(
- outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)],
- recordsWritten: Long): Unit = {
- if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) {
- outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) =>
- om.setBytesWritten(callback())
- om.setRecordsWritten(recordsWritten)
- }
- }
- }
-
/**
* Return an RDD with the keys of each tuple.
*/
@@ -1258,22 +1163,4 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
private[spark] def valueClass: Class[_] = vt.runtimeClass
private[spark] def keyOrdering: Option[Ordering[K]] = Option(ord)
-
- // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation
- // setting can take effect:
- private def isOutputSpecValidationEnabled: Boolean = {
- val validationDisabled = PairRDDFunctions.disableOutputSpecValidation.value
- val enabledInConf = self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)
- enabledInConf && !validationDisabled
- }
-}
-
-private[spark] object PairRDDFunctions {
- val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256
-
- /**
- * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case
- * basis; see SPARK-4835 for more details.
- */
- val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false)
}
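
The body of saveAsNewAPIHadoopDataset now delegates to SparkHadoopMapReduceWriter, and the private helpers deleted above are referenced through SparkHadoopWriterUtils instead. A minimal sketch of what that shared utilities object could look like, reconstructed only from the removed lines; the object's package, visibility and exact signatures are assumptions.

```scala
package org.apache.spark.internal.io  // assumed location

import scala.util.DynamicVariable

import org.apache.spark.{SparkConf, TaskContext}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.OutputMetrics

// Sketch of the shared writer utilities the deleted private helpers are assumed to
// migrate into; the method bodies mirror the removed code above.
private[spark] object SparkHadoopWriterUtils {

  // Update the bytes-written metric only every few records to keep overhead low.
  val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256

  // Allows the `spark.hadoop.validateOutputSpecs` check to be disabled on a
  // case-by-case basis (see SPARK-4835).
  val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false)

  def isOutputSpecValidationEnabled(conf: SparkConf): Boolean = {
    val validationDisabled = disableOutputSpecValidation.value
    val enabledInConf = conf.getBoolean("spark.hadoop.validateOutputSpecs", true)
    enabledInConf && !validationDisabled
  }

  // Returns (output metrics, bytes-written callback), defined only if the callback exists.
  def initHadoopOutputMetrics(context: TaskContext): Option[(OutputMetrics, () => Long)] = {
    SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().map { b =>
      (context.taskMetrics().outputMetrics, b)
    }
  }

  def maybeUpdateOutputMetrics(
      outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)],
      recordsWritten: Long): Unit = {
    if (recordsWritten % RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) {
      outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) =>
        om.setBytesWritten(callback())
        om.setRecordsWritten(recordsWritten)
      }
    }
  }
}
```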
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index db535de9e9bb..e018af35cb18 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -788,14 +788,26 @@ abstract class RDD[T: ClassTag](
}
/**
- * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a
- * performance API to be used carefully only if we are sure that the RDD elements are
+ * [performance] Spark's internal mapPartitionsWithIndex method that skips closure cleaning.
+ * It is a performance API to be used carefully only if we are sure that the RDD elements are
* serializable and don't require closure cleaning.
*
* @param preservesPartitioning indicates whether the input function preserves the partitioner,
* which should be `false` unless this is a pair RDD and the input function doesn't modify
* the keys.
*/
+ private[spark] def mapPartitionsWithIndexInternal[U: ClassTag](
+ f: (Int, Iterator[T]) => Iterator[U],
+ preservesPartitioning: Boolean = false): RDD[U] = withScope {
+ new MapPartitionsRDD(
+ this,
+ (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter),
+ preservesPartitioning)
+ }
+
+ /**
+ * [performance] Spark's internal mapPartitions method that skips closure cleaning.
+ */
private[spark] def mapPartitionsInternal[U: ClassTag](
f: Iterator[T] => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = withScope {
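
mapPartitionsWithIndexInternal is a private[spark] convenience that skips closure cleaning while still exposing the partition index. A hypothetical internal caller might look like the sketch below; the package and closure are illustrative, and the closure must already be known to be serializable.

```scala
package org.apache.spark.example  // hypothetical; must live under org.apache.spark to see private[spark] APIs

import org.apache.spark.{SparkConf, SparkContext}

object MapPartitionsWithIndexInternalExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("example"))
    try {
      val data = sc.parallelize(1 to 10, numSlices = 2)
      // Tag each element with its partition index without paying for closure cleaning;
      // safe here because the closure captures nothing that is not serializable.
      val tagged = data.mapPartitionsWithIndexInternal { (index, iter) =>
        iter.map(value => (index, value))
      }
      tagged.collect().foreach(println)
    } finally {
      sc.stop()
    }
  }
}
```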
diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
index eac901d10067..9f800e3a0953 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
@@ -239,12 +239,17 @@ private[spark] object ReliableCheckpointRDD extends Logging {
val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration)
val fileInputStream = fs.open(partitionerFilePath, bufferSize)
val serializer = SparkEnv.get.serializer.newInstance()
- val deserializeStream = serializer.deserializeStream(fileInputStream)
- val partitioner = Utils.tryWithSafeFinally[Partitioner] {
- deserializeStream.readObject[Partitioner]
+ val partitioner = Utils.tryWithSafeFinally {
+ val deserializeStream = serializer.deserializeStream(fileInputStream)
+ Utils.tryWithSafeFinally {
+ deserializeStream.readObject[Partitioner]
+ } {
+ deserializeStream.close()
+ }
} {
- deserializeStream.close()
+ fileInputStream.close()
}
+
logDebug(s"Read partitioner from $partitionerFilePath")
Some(partitioner)
} catch {
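
The checkpoint fix above ensures the underlying file stream is closed even when constructing or using the deserialization stream fails. A self-contained sketch of the same nested-close pattern; the local safeFinally helper only mimics the assumed semantics of Utils.tryWithSafeFinally (run the block, always run the finalizer, and do not let a finalizer failure mask the original exception).

```scala
import java.io.{BufferedInputStream, FileInputStream, ObjectInputStream}

object NestedCloseExample {
  // Minimal stand-in for Utils.tryWithSafeFinally.
  private def safeFinally[T](block: => T)(finallyBlock: => Unit): T = {
    var originalThrowable: Throwable = null
    try {
      block
    } catch {
      case t: Throwable =>
        originalThrowable = t
        throw t
    } finally {
      try {
        finallyBlock
      } catch {
        // Keep the original failure; record the close failure as suppressed.
        case t: Throwable if originalThrowable != null =>
          originalThrowable.addSuppressed(t)
      }
    }
  }

  def readObjectFrom(path: String): AnyRef = {
    val fileIn = new FileInputStream(path)
    safeFinally {
      // If constructing the ObjectInputStream throws, the outer block still closes fileIn.
      val objIn = new ObjectInputStream(new BufferedInputStream(fileIn))
      safeFinally {
        objIn.readObject()
      } {
        objIn.close()
      }
    } {
      fileIn.close()
    }
  }
}
```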
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 579122868afc..bbc416381490 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -147,6 +147,10 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
*/
def openChannel(uri: String): ReadableByteChannel
+ /**
+ * Return whether the current thread is an RPC thread.
+ */
+ def isInRPCThread: Boolean
}
/**
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
index a02cf30a5d83..67baabd2cbff 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -201,6 +201,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
/** Message loop used for dispatching messages. */
private class MessageLoop extends Runnable {
override def run(): Unit = {
+ NettyRpcEnv.rpcThreadFlag.value = true
try {
while (true) {
try {
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index e51649a1ecce..0b8cd144a216 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -408,10 +408,13 @@ private[netty] class NettyRpcEnv(
}
+ override def isInRPCThread: Boolean = NettyRpcEnv.rpcThreadFlag.value
}
private[netty] object NettyRpcEnv extends Logging {
+ private[netty] val rpcThreadFlag = new DynamicVariable[Boolean](false)
+
/**
* When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]].
* Use `currentEnv` to wrap the deserialization codes. E.g.,
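
isInRPCThread works because scala.util.DynamicVariable is backed by an InheritableThreadLocal: assigning .value affects only the current thread (and threads it later spawns), so the dispatcher's message loops can mark themselves without touching any other thread. A standalone sketch of that flag pattern, independent of Spark.

```scala
import scala.util.DynamicVariable

object RpcThreadFlagSketch {
  // Defaults to false everywhere; each thread sees its own value once it assigns one.
  private val rpcThreadFlag = new DynamicVariable[Boolean](false)

  def isInRpcThread: Boolean = rpcThreadFlag.value

  def main(args: Array[String]): Unit = {
    val worker = new Thread(new Runnable {
      override def run(): Unit = {
        rpcThreadFlag.value = true                               // visible only on this thread
        println(s"worker sees flag = ${rpcThreadFlag.value}")    // true
      }
    })
    worker.start()
    worker.join()
    println(s"main sees flag = ${rpcThreadFlag.value}")          // still false
  }
}
```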
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index f2517401cb76..7fde34d8974c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1089,7 +1089,8 @@ class DAGScheduler(
// To avoid UI cruft, ignore cases where value wasn't updated
if (acc.name.isDefined && !updates.isZero) {
stage.latestInfo.accumulables(id) = acc.toInfo(None, Some(acc.value))
- event.taskInfo.accumulables += acc.toInfo(Some(updates.value), Some(acc.value))
+ event.taskInfo.setAccumulables(
+ acc.toInfo(Some(updates.value), Some(acc.value)) +: event.taskInfo.accumulables)
}
}
} catch {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
index 3eff8d952bfd..0bd5a6bc59a9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
@@ -53,13 +53,24 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging {
sourceName: String,
maybeTruncated: Boolean = false,
eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): Unit = {
+ val lines = Source.fromInputStream(logData).getLines()
+ replay(lines, sourceName, maybeTruncated, eventsFilter)
+ }
+ /**
+ * Overloaded variant of [[replay()]] which accepts an iterator of lines instead of an
+ * [[InputStream]]. Exposed for use by custom ApplicationHistoryProvider implementations.
+ */
+ def replay(
+ lines: Iterator[String],
+ sourceName: String,
+ maybeTruncated: Boolean,
+ eventsFilter: ReplayEventsFilter): Unit = {
var currentLine: String = null
var lineNumber: Int = 0
try {
- val lineEntries = Source.fromInputStream(logData)
- .getLines()
+ val lineEntries = lines
.zipWithIndex
.filter { case (line, _) => eventsFilter(line) }
@@ -72,6 +83,10 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging {
postToAll(JsonProtocol.sparkEventFromJson(parse(currentLine)))
} catch {
+ case e: ClassNotFoundException if KNOWN_REMOVED_CLASSES.contains(e.getMessage) =>
+ // Ignore events generated by Structured Streaming in Spark 2.0.0 and 2.0.1.
+ // This is safe because these events are no longer used anywhere.
+ logWarning(s"Dropped incompatible Structured Streaming log: $currentLine")
case jpe: JsonParseException =>
// We can only ignore exception from last line of the file that might be truncated
// the last entry may not be the very last line in the event log, but we treat it
@@ -102,4 +117,13 @@ private[spark] object ReplayListenerBus {
// utility filter that selects all event logs during replay
val SELECT_ALL_FILTER: ReplayEventsFilter = { (eventString: String) => true }
+
+ /**
+ * Classes that were removed. Structured Streaming no longer uses them. Parsing old JSON
+ * that references them may fail, and such failures can safely be ignored.
+ */
+ val KNOWN_REMOVED_CLASSES = Set(
+ "org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgress",
+ "org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminated"
+ )
}
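
The new replay overload takes an Iterator[String] so that custom ApplicationHistoryProvider implementations can feed pre-fetched event-log lines directly into the bus. A hypothetical caller is sketched below; it is illustrative only and must live under the org.apache.spark namespace because the bus is private[spark].

```scala
package org.apache.spark.example  // hypothetical

import org.apache.spark.scheduler.{ReplayListenerBus, SparkListener, SparkListenerApplicationStart}
import org.apache.spark.scheduler.ReplayListenerBus.SELECT_ALL_FILTER

object CustomReplayExample {
  def replayFromLines(eventLogLines: Seq[String]): Unit = {
    val bus = new ReplayListenerBus()
    bus.addListener(new SparkListener {
      override def onApplicationStart(event: SparkListenerApplicationStart): Unit = {
        println(s"Replayed start of application ${event.appName}")
      }
    })
    // maybeTruncated = false: the supplied lines are complete, so a parse failure
    // on the last line is not tolerated.
    bus.replay(eventLogLines.iterator, sourceName = "custom-source", maybeTruncated = false,
      eventsFilter = SELECT_ALL_FILTER)
  }
}
```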
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index eeb7963c9e61..59680139e7af 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -17,8 +17,6 @@
package org.apache.spark.scheduler
-import scala.collection.mutable.ListBuffer
-
import org.apache.spark.TaskState
import org.apache.spark.TaskState.TaskState
import org.apache.spark.annotation.DeveloperApi
@@ -54,7 +52,13 @@ class TaskInfo(
* accumulable to be updated multiple times in a single task or for two accumulables with the
* same name but different IDs to exist in a task.
*/
- val accumulables = ListBuffer[AccumulableInfo]()
+ def accumulables: Seq[AccumulableInfo] = _accumulables
+
+ private[this] var _accumulables: Seq[AccumulableInfo] = Nil
+
+ private[spark] def setAccumulables(newAccumulables: Seq[AccumulableInfo]): Unit = {
+ _accumulables = newAccumulables
+ }
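
TaskInfo.accumulables is now an immutable Seq behind a getter, with a private[spark] setter, so callers replace the whole collection instead of mutating a ListBuffer in place (as the DAGScheduler hunk above does). A simplified, illustrative stand-in for that copy-on-write update pattern.

```scala
// Simplified stand-in for TaskInfo's new accumulables API; names mirror the diff,
// but the class itself is illustrative only.
final class TaskInfoLike {
  private[this] var _accumulables: Seq[String] = Nil
  def accumulables: Seq[String] = _accumulables
  def setAccumulables(newAccumulables: Seq[String]): Unit = {
    _accumulables = newAccumulables
  }
}

object AccumulablesUpdateExample {
  def main(args: Array[String]): Unit = {
    val info = new TaskInfoLike
    // The old style was `info.accumulables += update`; with an immutable Seq the caller
    // builds the new collection and installs it with one assignment instead.
    info.setAccumulables("acc.toInfo(update, value)" +: info.accumulables)
    println(info.accumulables)
  }
}
```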
/**
* The time when the task has completed successfully (including the time to remotely fetch
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
index 04d40e2907cf..368cd30a2e11 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
@@ -93,7 +93,7 @@ private[spark] class StandaloneSchedulerBackend(
val javaOpts = sparkJavaOpts ++ extraJavaOpts
val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend",
args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts)
- val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("")
+ val webUrl = sc.ui.map(_.webUrl).getOrElse("")
val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt)
// If we're using dynamic allocation, set our initial executor limit to 0 for now.
// ExecutorAllocationManager will send the real initial limit to the Master later.
@@ -103,8 +103,8 @@ private[spark] class StandaloneSchedulerBackend(
} else {
None
}
- val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command,
- appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit)
+ val appDesc = ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command,
+ webUrl, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit)
client = new StandaloneAppClient(sc.env.rpcEnv, masters, appDesc, this, conf)
client.start()
launcherBackend.setState(SparkAppHandle.State.SUBMITTED)
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala
index f6a9f9c5573d..76af33c1a18d 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala
@@ -21,7 +21,7 @@ import java.lang.annotation.Annotation
import java.lang.reflect.Type
import java.nio.charset.StandardCharsets
import java.text.SimpleDateFormat
-import java.util.{Calendar, SimpleTimeZone}
+import java.util.{Calendar, Locale, SimpleTimeZone}
import javax.ws.rs.Produces
import javax.ws.rs.core.{MediaType, MultivaluedMap}
import javax.ws.rs.ext.{MessageBodyWriter, Provider}
@@ -86,7 +86,7 @@ private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{
private[spark] object JacksonMessageWriter {
def makeISODateFormat: SimpleDateFormat = {
- val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'")
+ val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'", Locale.US)
val cal = Calendar.getInstance(new SimpleTimeZone(0, "GMT"))
iso8601.setCalendar(cal)
iso8601
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala
index 0c71cd238222..d8d5e8958b23 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala
@@ -17,7 +17,7 @@
package org.apache.spark.status.api.v1
import java.text.{ParseException, SimpleDateFormat}
-import java.util.TimeZone
+import java.util.{Locale, TimeZone}
import javax.ws.rs.WebApplicationException
import javax.ws.rs.core.Response
import javax.ws.rs.core.Response.Status
@@ -25,12 +25,12 @@ import javax.ws.rs.core.Response.Status
private[v1] class SimpleDateParam(val originalValue: String) {
val timestamp: Long = {
- val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz")
+ val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz", Locale.US)
try {
format.parse(originalValue).getTime()
} catch {
case _: ParseException =>
- val gmtDay = new SimpleDateFormat("yyyy-MM-dd")
+ val gmtDay = new SimpleDateFormat("yyyy-MM-dd", Locale.US)
gmtDay.setTimeZone(TimeZone.getTimeZone("GMT"))
try {
gmtDay.parse(originalValue).getTime()
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index f631a047a707..b828532aba7a 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -82,7 +82,7 @@ private[spark] class SparkUI private (
initialize()
def getSparkUser: String = {
- environmentListener.systemProperties.toMap.get("user.name").getOrElse("")
+ environmentListener.systemProperties.toMap.getOrElse("user.name", "")
}
def getAppName: String = appName
@@ -94,16 +94,9 @@ private[spark] class SparkUI private (
/** Stop the server behind this web interface. Only valid after bind(). */
override def stop() {
super.stop()
- logInfo("Stopped Spark web UI at %s".format(appUIAddress))
+ logInfo(s"Stopped Spark web UI at $webUrl")
}
- /**
- * Return the application UI host:port. This does not include the scheme (http://).
- */
- private[spark] def appUIHostPort = publicHostName + ":" + boundPort
-
- private[spark] def appUIAddress = s"http://$appUIHostPort"
-
def getSparkUI(appId: String): Option[SparkUI] = {
if (appId == this.appId) Some(this) else None
}
@@ -136,7 +129,7 @@ private[spark] class SparkUI private (
private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String)
extends WebUITab(parent, prefix) {
- def appName: String = parent.getAppName
+ def appName: String = parent.appName
}
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index c0d1a2220f62..57f6f2f0a9be 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -36,7 +36,8 @@ private[spark] object UIUtils extends Logging {
// SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use.
private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
- override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+ override def initialValue(): SimpleDateFormat =
+ new SimpleDateFormat("yyyy/MM/dd HH:mm:ss", Locale.US)
}
def formatDate(date: Date): String = dateFormat.get.format(date)
@@ -170,6 +171,7 @@ private[spark] object UIUtils extends Logging {
+
}
def vizHeaderNodes: Seq[Node] = {
diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
index a05e0efb7a3e..8c801558672f 100644
--- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
@@ -56,8 +56,8 @@ private[spark] abstract class WebUI(
private val className = Utils.getFormattedClassName(this)
def getBasePath: String = basePath
- def getTabs: Seq[WebUITab] = tabs.toSeq
- def getHandlers: Seq[ServletContextHandler] = handlers.toSeq
+ def getTabs: Seq[WebUITab] = tabs
+ def getHandlers: Seq[ServletContextHandler] = handlers
def getSecurityManager: SecurityManager = securityManager
/** Attach a tab to this UI, along with all of its attached pages. */
@@ -133,7 +133,7 @@ private[spark] abstract class WebUI(
def initialize(): Unit
/** Bind to the HTTP server behind this web interface. */
- def bind() {
+ def bind(): Unit = {
assert(!serverInfo.isDefined, s"Attempted to bind $className more than once!")
try {
val host = Option(conf.getenv("SPARK_LOCAL_IP")).getOrElse("0.0.0.0")
@@ -156,7 +156,7 @@ private[spark] abstract class WebUI(
def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1)
/** Stop the server behind this web interface. Only valid after bind(). */
- def stop() {
+ def stop(): Unit = {
assert(serverInfo.isDefined,
s"Attempted to stop $className before binding to a server!")
serverInfo.get.stop()
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
index a0ef80d9bdae..c6a07445f2a3 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
@@ -48,6 +48,16 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage
}
}.map { thread =>
val threadId = thread.threadId
+ val blockedBy = thread.blockedByThreadId match {
+ case Some(blockedByThreadId) =>
+
+ Blocked by
+ Thread {thread.blockedByThreadId} {thread.blockedByLock}
+
+ case None => Text("")
+ }
+ val heldLocks = thread.holdingLocks.mkString(", ")
+
{threadId}
{thread.threadName}
{thread.threadState}
+ {blockedBy}{heldLocks}
{thread.stackTrace}
}
@@ -86,6 +97,7 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage
Thread ID
Thread Name
Thread State
+ Thread Locks
{dumpRows}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
index 173fc3cf31ce..50e8e2d19e15 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
@@ -289,8 +289,8 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
val startTime = listener.startTime
val endTime = listener.endTime
val activeJobs = listener.activeJobs.values.toSeq
- val completedJobs = listener.completedJobs.reverse.toSeq
- val failedJobs = listener.failedJobs.reverse.toSeq
+ val completedJobs = listener.completedJobs.reverse
+ val failedJobs = listener.failedJobs.reverse
val activeJobsTable =
jobsTable(request, "active", "activeJob", activeJobs, killEnabled = parent.killEnabled)
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
index f4a04609c4c6..9ce8542f0279 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import scala.collection.mutable.{HashMap, LinkedHashMap}
import org.apache.spark.JobExecutionStatus
-import org.apache.spark.executor.{ShuffleReadMetrics, ShuffleWriteMetrics, TaskMetrics}
+import org.apache.spark.executor._
import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo}
import org.apache.spark.util.AccumulatorContext
import org.apache.spark.util.collection.OpenHashSet
@@ -147,9 +147,8 @@ private[spark] object UIData {
memoryBytesSpilled = m.memoryBytesSpilled,
diskBytesSpilled = m.diskBytesSpilled,
peakExecutionMemory = m.peakExecutionMemory,
- inputMetrics = InputMetricsUIData(m.inputMetrics.bytesRead, m.inputMetrics.recordsRead),
- outputMetrics =
- OutputMetricsUIData(m.outputMetrics.bytesWritten, m.outputMetrics.recordsWritten),
+ inputMetrics = InputMetricsUIData(m.inputMetrics),
+ outputMetrics = OutputMetricsUIData(m.outputMetrics),
shuffleReadMetrics = ShuffleReadMetricsUIData(m.shuffleReadMetrics),
shuffleWriteMetrics = ShuffleWriteMetricsUIData(m.shuffleWriteMetrics))
}
@@ -171,9 +170,9 @@ private[spark] object UIData {
speculative = taskInfo.speculative
)
newTaskInfo.gettingResultTime = taskInfo.gettingResultTime
- newTaskInfo.accumulables ++= taskInfo.accumulables.filter {
+ newTaskInfo.setAccumulables(taskInfo.accumulables.filter {
accum => !accum.internal && accum.metadata != Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)
- }
+ })
newTaskInfo.finishTime = taskInfo.finishTime
newTaskInfo.failed = taskInfo.failed
newTaskInfo
@@ -197,8 +196,32 @@ private[spark] object UIData {
shuffleWriteMetrics: ShuffleWriteMetricsUIData)
case class InputMetricsUIData(bytesRead: Long, recordsRead: Long)
+ object InputMetricsUIData {
+ def apply(metrics: InputMetrics): InputMetricsUIData = {
+ if (metrics.bytesRead == 0 && metrics.recordsRead == 0) {
+ EMPTY
+ } else {
+ new InputMetricsUIData(
+ bytesRead = metrics.bytesRead,
+ recordsRead = metrics.recordsRead)
+ }
+ }
+ private val EMPTY = InputMetricsUIData(0, 0)
+ }
case class OutputMetricsUIData(bytesWritten: Long, recordsWritten: Long)
+ object OutputMetricsUIData {
+ def apply(metrics: OutputMetrics): OutputMetricsUIData = {
+ if (metrics.bytesWritten == 0 && metrics.recordsWritten == 0) {
+ EMPTY
+ } else {
+ new OutputMetricsUIData(
+ bytesWritten = metrics.bytesWritten,
+ recordsWritten = metrics.recordsWritten)
+ }
+ }
+ private val EMPTY = OutputMetricsUIData(0, 0)
+ }
case class ShuffleReadMetricsUIData(
remoteBlocksFetched: Long,
@@ -212,17 +235,30 @@ private[spark] object UIData {
object ShuffleReadMetricsUIData {
def apply(metrics: ShuffleReadMetrics): ShuffleReadMetricsUIData = {
- new ShuffleReadMetricsUIData(
- remoteBlocksFetched = metrics.remoteBlocksFetched,
- localBlocksFetched = metrics.localBlocksFetched,
- remoteBytesRead = metrics.remoteBytesRead,
- localBytesRead = metrics.localBytesRead,
- fetchWaitTime = metrics.fetchWaitTime,
- recordsRead = metrics.recordsRead,
- totalBytesRead = metrics.totalBytesRead,
- totalBlocksFetched = metrics.totalBlocksFetched
- )
+ if (
+ metrics.remoteBlocksFetched == 0 &&
+ metrics.localBlocksFetched == 0 &&
+ metrics.remoteBytesRead == 0 &&
+ metrics.localBytesRead == 0 &&
+ metrics.fetchWaitTime == 0 &&
+ metrics.recordsRead == 0 &&
+ metrics.totalBytesRead == 0 &&
+ metrics.totalBlocksFetched == 0) {
+ EMPTY
+ } else {
+ new ShuffleReadMetricsUIData(
+ remoteBlocksFetched = metrics.remoteBlocksFetched,
+ localBlocksFetched = metrics.localBlocksFetched,
+ remoteBytesRead = metrics.remoteBytesRead,
+ localBytesRead = metrics.localBytesRead,
+ fetchWaitTime = metrics.fetchWaitTime,
+ recordsRead = metrics.recordsRead,
+ totalBytesRead = metrics.totalBytesRead,
+ totalBlocksFetched = metrics.totalBlocksFetched
+ )
+ }
}
+ private val EMPTY = ShuffleReadMetricsUIData(0, 0, 0, 0, 0, 0, 0, 0)
}
case class ShuffleWriteMetricsUIData(
@@ -232,12 +268,17 @@ private[spark] object UIData {
object ShuffleWriteMetricsUIData {
def apply(metrics: ShuffleWriteMetrics): ShuffleWriteMetricsUIData = {
- new ShuffleWriteMetricsUIData(
- bytesWritten = metrics.bytesWritten,
- recordsWritten = metrics.recordsWritten,
- writeTime = metrics.writeTime
- )
+ if (metrics.bytesWritten == 0 && metrics.recordsWritten == 0 && metrics.writeTime == 0) {
+ EMPTY
+ } else {
+ new ShuffleWriteMetricsUIData(
+ bytesWritten = metrics.bytesWritten,
+ recordsWritten = metrics.recordsWritten,
+ writeTime = metrics.writeTime
+ )
+ }
}
+ private val EMPTY = ShuffleWriteMetricsUIData(0, 0, 0)
}
}
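
The UIData changes above replace unconditional allocation with a shared EMPTY instance whenever every field is zero, which is the common case for tasks that do no shuffle, input, or output. A minimal, standalone sketch of that flyweight pattern; the names here are illustrative, not Spark's.

```scala
case class InputMetricsSketch(bytesRead: Long, recordsRead: Long)

object InputMetricsSketch {
  // One shared instance for the all-zero case, so thousands of idle tasks share a single object.
  private val EMPTY = InputMetricsSketch(0L, 0L)

  def of(bytesRead: Long, recordsRead: Long): InputMetricsSketch = {
    if (bytesRead == 0L && recordsRead == 0L) EMPTY
    else InputMetricsSketch(bytesRead, recordsRead)
  }
}

object FlyweightDemo {
  def main(args: Array[String]): Unit = {
    val a = InputMetricsSketch.of(0L, 0L)
    val b = InputMetricsSketch.of(0L, 0L)
    println(a eq b)                          // true: same shared instance
    println(InputMetricsSketch.of(10L, 1L))  // fresh instance when any field is non-zero
  }
}
```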
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index c11eb3ffa460..4b4d2d10cbf8 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -107,20 +107,20 @@ private[spark] object JsonProtocol {
def stageSubmittedToJson(stageSubmitted: SparkListenerStageSubmitted): JValue = {
val stageInfo = stageInfoToJson(stageSubmitted.stageInfo)
val properties = propertiesToJson(stageSubmitted.properties)
- ("Event" -> Utils.getFormattedClassName(stageSubmitted)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.stageSubmitted) ~
("Stage Info" -> stageInfo) ~
("Properties" -> properties)
}
def stageCompletedToJson(stageCompleted: SparkListenerStageCompleted): JValue = {
val stageInfo = stageInfoToJson(stageCompleted.stageInfo)
- ("Event" -> Utils.getFormattedClassName(stageCompleted)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.stageCompleted) ~
("Stage Info" -> stageInfo)
}
def taskStartToJson(taskStart: SparkListenerTaskStart): JValue = {
val taskInfo = taskStart.taskInfo
- ("Event" -> Utils.getFormattedClassName(taskStart)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskStart) ~
("Stage ID" -> taskStart.stageId) ~
("Stage Attempt ID" -> taskStart.stageAttemptId) ~
("Task Info" -> taskInfoToJson(taskInfo))
@@ -128,7 +128,7 @@ private[spark] object JsonProtocol {
def taskGettingResultToJson(taskGettingResult: SparkListenerTaskGettingResult): JValue = {
val taskInfo = taskGettingResult.taskInfo
- ("Event" -> Utils.getFormattedClassName(taskGettingResult)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskGettingResult) ~
("Task Info" -> taskInfoToJson(taskInfo))
}
@@ -137,7 +137,7 @@ private[spark] object JsonProtocol {
val taskInfo = taskEnd.taskInfo
val taskMetrics = taskEnd.taskMetrics
val taskMetricsJson = if (taskMetrics != null) taskMetricsToJson(taskMetrics) else JNothing
- ("Event" -> Utils.getFormattedClassName(taskEnd)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskEnd) ~
("Stage ID" -> taskEnd.stageId) ~
("Stage Attempt ID" -> taskEnd.stageAttemptId) ~
("Task Type" -> taskEnd.taskType) ~
@@ -148,7 +148,7 @@ private[spark] object JsonProtocol {
def jobStartToJson(jobStart: SparkListenerJobStart): JValue = {
val properties = propertiesToJson(jobStart.properties)
- ("Event" -> Utils.getFormattedClassName(jobStart)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.jobStart) ~
("Job ID" -> jobStart.jobId) ~
("Submission Time" -> jobStart.time) ~
("Stage Infos" -> jobStart.stageInfos.map(stageInfoToJson)) ~ // Added in Spark 1.2.0
@@ -158,7 +158,7 @@ private[spark] object JsonProtocol {
def jobEndToJson(jobEnd: SparkListenerJobEnd): JValue = {
val jobResult = jobResultToJson(jobEnd.jobResult)
- ("Event" -> Utils.getFormattedClassName(jobEnd)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.jobEnd) ~
("Job ID" -> jobEnd.jobId) ~
("Completion Time" -> jobEnd.time) ~
("Job Result" -> jobResult)
@@ -170,7 +170,7 @@ private[spark] object JsonProtocol {
val sparkProperties = mapToJson(environmentDetails("Spark Properties").toMap)
val systemProperties = mapToJson(environmentDetails("System Properties").toMap)
val classpathEntries = mapToJson(environmentDetails("Classpath Entries").toMap)
- ("Event" -> Utils.getFormattedClassName(environmentUpdate)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.environmentUpdate) ~
("JVM Information" -> jvmInformation) ~
("Spark Properties" -> sparkProperties) ~
("System Properties" -> systemProperties) ~
@@ -179,7 +179,7 @@ private[spark] object JsonProtocol {
def blockManagerAddedToJson(blockManagerAdded: SparkListenerBlockManagerAdded): JValue = {
val blockManagerId = blockManagerIdToJson(blockManagerAdded.blockManagerId)
- ("Event" -> Utils.getFormattedClassName(blockManagerAdded)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.blockManagerAdded) ~
("Block Manager ID" -> blockManagerId) ~
("Maximum Memory" -> blockManagerAdded.maxMem) ~
("Timestamp" -> blockManagerAdded.time)
@@ -187,18 +187,18 @@ private[spark] object JsonProtocol {
def blockManagerRemovedToJson(blockManagerRemoved: SparkListenerBlockManagerRemoved): JValue = {
val blockManagerId = blockManagerIdToJson(blockManagerRemoved.blockManagerId)
- ("Event" -> Utils.getFormattedClassName(blockManagerRemoved)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.blockManagerRemoved) ~
("Block Manager ID" -> blockManagerId) ~
("Timestamp" -> blockManagerRemoved.time)
}
def unpersistRDDToJson(unpersistRDD: SparkListenerUnpersistRDD): JValue = {
- ("Event" -> Utils.getFormattedClassName(unpersistRDD)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.unpersistRDD) ~
("RDD ID" -> unpersistRDD.rddId)
}
def applicationStartToJson(applicationStart: SparkListenerApplicationStart): JValue = {
- ("Event" -> Utils.getFormattedClassName(applicationStart)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.applicationStart) ~
("App Name" -> applicationStart.appName) ~
("App ID" -> applicationStart.appId.map(JString(_)).getOrElse(JNothing)) ~
("Timestamp" -> applicationStart.time) ~
@@ -208,33 +208,33 @@ private[spark] object JsonProtocol {
}
def applicationEndToJson(applicationEnd: SparkListenerApplicationEnd): JValue = {
- ("Event" -> Utils.getFormattedClassName(applicationEnd)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.applicationEnd) ~
("Timestamp" -> applicationEnd.time)
}
def executorAddedToJson(executorAdded: SparkListenerExecutorAdded): JValue = {
- ("Event" -> Utils.getFormattedClassName(executorAdded)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.executorAdded) ~
("Timestamp" -> executorAdded.time) ~
("Executor ID" -> executorAdded.executorId) ~
("Executor Info" -> executorInfoToJson(executorAdded.executorInfo))
}
def executorRemovedToJson(executorRemoved: SparkListenerExecutorRemoved): JValue = {
- ("Event" -> Utils.getFormattedClassName(executorRemoved)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.executorRemoved) ~
("Timestamp" -> executorRemoved.time) ~
("Executor ID" -> executorRemoved.executorId) ~
("Removed Reason" -> executorRemoved.reason)
}
def logStartToJson(logStart: SparkListenerLogStart): JValue = {
- ("Event" -> Utils.getFormattedClassName(logStart)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.logStart) ~
("Spark Version" -> SPARK_VERSION)
}
def executorMetricsUpdateToJson(metricsUpdate: SparkListenerExecutorMetricsUpdate): JValue = {
val execId = metricsUpdate.execId
val accumUpdates = metricsUpdate.accumUpdates
- ("Event" -> Utils.getFormattedClassName(metricsUpdate)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.metricsUpdate) ~
("Executor ID" -> execId) ~
("Metrics Updated" -> accumUpdates.map { case (taskId, stageId, stageAttemptId, updates) =>
("Task ID" -> taskId) ~
@@ -485,7 +485,7 @@ private[spark] object JsonProtocol {
* JSON deserialization methods for SparkListenerEvents |
* ---------------------------------------------------- */
- def sparkEventFromJson(json: JValue): SparkListenerEvent = {
+ private object SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES {
val stageSubmitted = Utils.getFormattedClassName(SparkListenerStageSubmitted)
val stageCompleted = Utils.getFormattedClassName(SparkListenerStageCompleted)
val taskStart = Utils.getFormattedClassName(SparkListenerTaskStart)
@@ -503,6 +503,10 @@ private[spark] object JsonProtocol {
val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved)
val logStart = Utils.getFormattedClassName(SparkListenerLogStart)
val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate)
+ }
+
+ def sparkEventFromJson(json: JValue): SparkListenerEvent = {
+ import SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES._
(json \ "Event").extract[String] match {
case `stageSubmitted` => stageSubmittedFromJson(json)
@@ -540,7 +544,8 @@ private[spark] object JsonProtocol {
def taskStartFromJson(json: JValue): SparkListenerTaskStart = {
val stageId = (json \ "Stage ID").extract[Int]
- val stageAttemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0)
+ val stageAttemptId =
+ Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0)
val taskInfo = taskInfoFromJson(json \ "Task Info")
SparkListenerTaskStart(stageId, stageAttemptId, taskInfo)
}
@@ -552,7 +557,8 @@ private[spark] object JsonProtocol {
def taskEndFromJson(json: JValue): SparkListenerTaskEnd = {
val stageId = (json \ "Stage ID").extract[Int]
- val stageAttemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0)
+ val stageAttemptId =
+ Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0)
val taskType = (json \ "Task Type").extract[String]
val taskEndReason = taskEndReasonFromJson(json \ "Task End Reason")
val taskInfo = taskInfoFromJson(json \ "Task Info")
@@ -662,20 +668,22 @@ private[spark] object JsonProtocol {
def stageInfoFromJson(json: JValue): StageInfo = {
val stageId = (json \ "Stage ID").extract[Int]
- val attemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0)
+ val attemptId = Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0)
val stageName = (json \ "Stage Name").extract[String]
val numTasks = (json \ "Number of Tasks").extract[Int]
val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson)
val parentIds = Utils.jsonOption(json \ "Parent IDs")
.map { l => l.extract[List[JValue]].map(_.extract[Int]) }
.getOrElse(Seq.empty)
- val details = (json \ "Details").extractOpt[String].getOrElse("")
+ val details = Utils.jsonOption(json \ "Details").map(_.extract[String]).getOrElse("")
val submissionTime = Utils.jsonOption(json \ "Submission Time").map(_.extract[Long])
val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long])
val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String])
- val accumulatedValues = (json \ "Accumulables").extractOpt[List[JValue]] match {
- case Some(values) => values.map(accumulableInfoFromJson)
- case None => Seq[AccumulableInfo]()
+ val accumulatedValues = {
+ Utils.jsonOption(json \ "Accumulables").map(_.extract[List[JValue]]) match {
+ case Some(values) => values.map(accumulableInfoFromJson)
+ case None => Seq[AccumulableInfo]()
+ }
}
val stageInfo = new StageInfo(
@@ -692,17 +700,17 @@ private[spark] object JsonProtocol {
def taskInfoFromJson(json: JValue): TaskInfo = {
val taskId = (json \ "Task ID").extract[Long]
val index = (json \ "Index").extract[Int]
- val attempt = (json \ "Attempt").extractOpt[Int].getOrElse(1)
+ val attempt = Utils.jsonOption(json \ "Attempt").map(_.extract[Int]).getOrElse(1)
val launchTime = (json \ "Launch Time").extract[Long]
- val executorId = (json \ "Executor ID").extract[String]
- val host = (json \ "Host").extract[String]
+ val executorId = (json \ "Executor ID").extract[String].intern()
+ val host = (json \ "Host").extract[String].intern()
val taskLocality = TaskLocality.withName((json \ "Locality").extract[String])
- val speculative = (json \ "Speculative").extractOpt[Boolean].getOrElse(false)
+ val speculative = Utils.jsonOption(json \ "Speculative").exists(_.extract[Boolean])
val gettingResultTime = (json \ "Getting Result Time").extract[Long]
val finishTime = (json \ "Finish Time").extract[Long]
val failed = (json \ "Failed").extract[Boolean]
- val killed = (json \ "Killed").extractOpt[Boolean].getOrElse(false)
- val accumulables = (json \ "Accumulables").extractOpt[Seq[JValue]] match {
+ val killed = Utils.jsonOption(json \ "Killed").exists(_.extract[Boolean])
+ val accumulables = Utils.jsonOption(json \ "Accumulables").map(_.extract[Seq[JValue]]) match {
case Some(values) => values.map(accumulableInfoFromJson)
case None => Seq[AccumulableInfo]()
}
@@ -713,18 +721,19 @@ private[spark] object JsonProtocol {
taskInfo.finishTime = finishTime
taskInfo.failed = failed
taskInfo.killed = killed
- accumulables.foreach { taskInfo.accumulables += _ }
+ taskInfo.setAccumulables(accumulables)
taskInfo
}
def accumulableInfoFromJson(json: JValue): AccumulableInfo = {
val id = (json \ "ID").extract[Long]
- val name = (json \ "Name").extractOpt[String]
+ val name = Utils.jsonOption(json \ "Name").map(_.extract[String])
val update = Utils.jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) }
val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) }
- val internal = (json \ "Internal").extractOpt[Boolean].getOrElse(false)
- val countFailedValues = (json \ "Count Failed Values").extractOpt[Boolean].getOrElse(false)
- val metadata = (json \ "Metadata").extractOpt[String]
+ val internal = Utils.jsonOption(json \ "Internal").exists(_.extract[Boolean])
+ val countFailedValues =
+ Utils.jsonOption(json \ "Count Failed Values").exists(_.extract[Boolean])
+ val metadata = Utils.jsonOption(json \ "Metadata").map(_.extract[String])
new AccumulableInfo(id, name, update, value, internal, countFailedValues, metadata)
}
@@ -782,9 +791,11 @@ private[spark] object JsonProtocol {
readMetrics.incRemoteBlocksFetched((readJson \ "Remote Blocks Fetched").extract[Int])
readMetrics.incLocalBlocksFetched((readJson \ "Local Blocks Fetched").extract[Int])
readMetrics.incRemoteBytesRead((readJson \ "Remote Bytes Read").extract[Long])
- readMetrics.incLocalBytesRead((readJson \ "Local Bytes Read").extractOpt[Long].getOrElse(0L))
+ readMetrics.incLocalBytesRead(
+ Utils.jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L))
readMetrics.incFetchWaitTime((readJson \ "Fetch Wait Time").extract[Long])
- readMetrics.incRecordsRead((readJson \ "Total Records Read").extractOpt[Long].getOrElse(0L))
+ readMetrics.incRecordsRead(
+ Utils.jsonOption(readJson \ "Total Records Read").map(_.extract[Long]).getOrElse(0L))
metrics.mergeShuffleReadMetrics()
}
@@ -793,8 +804,8 @@ private[spark] object JsonProtocol {
Utils.jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson =>
val writeMetrics = metrics.shuffleWriteMetrics
writeMetrics.incBytesWritten((writeJson \ "Shuffle Bytes Written").extract[Long])
- writeMetrics.incRecordsWritten((writeJson \ "Shuffle Records Written")
- .extractOpt[Long].getOrElse(0L))
+ writeMetrics.incRecordsWritten(
+ Utils.jsonOption(writeJson \ "Shuffle Records Written").map(_.extract[Long]).getOrElse(0L))
writeMetrics.incWriteTime((writeJson \ "Shuffle Write Time").extract[Long])
}
@@ -802,14 +813,16 @@ private[spark] object JsonProtocol {
Utils.jsonOption(json \ "Output Metrics").foreach { outJson =>
val outputMetrics = metrics.outputMetrics
outputMetrics.setBytesWritten((outJson \ "Bytes Written").extract[Long])
- outputMetrics.setRecordsWritten((outJson \ "Records Written").extractOpt[Long].getOrElse(0L))
+ outputMetrics.setRecordsWritten(
+ Utils.jsonOption(outJson \ "Records Written").map(_.extract[Long]).getOrElse(0L))
}
// Input metrics
Utils.jsonOption(json \ "Input Metrics").foreach { inJson =>
val inputMetrics = metrics.inputMetrics
inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long])
- inputMetrics.incRecordsRead((inJson \ "Records Read").extractOpt[Long].getOrElse(0L))
+ inputMetrics.incRecordsRead(
+ Utils.jsonOption(inJson \ "Records Read").map(_.extract[Long]).getOrElse(0L))
}
// Updated blocks
@@ -824,7 +837,7 @@ private[spark] object JsonProtocol {
metrics
}
- def taskEndReasonFromJson(json: JValue): TaskEndReason = {
+ private object TASK_END_REASON_FORMATTED_CLASS_NAMES {
val success = Utils.getFormattedClassName(Success)
val resubmitted = Utils.getFormattedClassName(Resubmitted)
val fetchFailed = Utils.getFormattedClassName(FetchFailed)
@@ -834,6 +847,10 @@ private[spark] object JsonProtocol {
val taskCommitDenied = Utils.getFormattedClassName(TaskCommitDenied)
val executorLostFailure = Utils.getFormattedClassName(ExecutorLostFailure)
val unknownReason = Utils.getFormattedClassName(UnknownReason)
+ }
+
+ def taskEndReasonFromJson(json: JValue): TaskEndReason = {
+ import TASK_END_REASON_FORMATTED_CLASS_NAMES._
(json \ "Reason").extract[String] match {
case `success` => Success
@@ -850,7 +867,8 @@ private[spark] object JsonProtocol {
val className = (json \ "Class Name").extract[String]
val description = (json \ "Description").extract[String]
val stackTrace = stackTraceFromJson(json \ "Stack Trace")
- val fullStackTrace = (json \ "Full Stack Trace").extractOpt[String].orNull
+ val fullStackTrace =
+ Utils.jsonOption(json \ "Full Stack Trace").map(_.extract[String]).orNull
// Fallback on getting accumulator updates from TaskMetrics, which was logged in Spark 1.x
val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates")
.map(_.extract[List[JValue]].map(accumulableInfoFromJson))
@@ -885,15 +903,19 @@ private[spark] object JsonProtocol {
if (json == JNothing) {
return null
}
- val executorId = (json \ "Executor ID").extract[String]
- val host = (json \ "Host").extract[String]
+ val executorId = (json \ "Executor ID").extract[String].intern()
+ val host = (json \ "Host").extract[String].intern()
val port = (json \ "Port").extract[Int]
BlockManagerId(executorId, host, port)
}
- def jobResultFromJson(json: JValue): JobResult = {
+ private object JOB_RESULT_FORMATTED_CLASS_NAMES {
val jobSucceeded = Utils.getFormattedClassName(JobSucceeded)
val jobFailed = Utils.getFormattedClassName(JobFailed)
+ }
+
+ def jobResultFromJson(json: JValue): JobResult = {
+ import JOB_RESULT_FORMATTED_CLASS_NAMES._
(json \ "Result").extract[String] match {
case `jobSucceeded` => JobSucceeded
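
Two recurring changes in JsonProtocol above are (1) computing the formatted event class names once in a holder object rather than on every parsed event, and (2) replacing extractOpt, which wraps a full extraction in a try/catch, with a cheap check for JNothing via Utils.jsonOption. A small standalone sketch of the second pattern; the jsonOption helper below is an assumption about the shape of the real Utils.jsonOption.

```scala
import org.json4s._
import org.json4s.jackson.JsonMethods.parse

object JsonOptionSketch {
  implicit val formats: Formats = DefaultFormats

  // Assumed shape of Utils.jsonOption: a missing field parses to JNothing, which maps to
  // None without the try/catch machinery that extractOpt goes through on every call.
  def jsonOption(json: JValue): Option[JValue] = json match {
    case JNothing => None
    case value: JValue => Some(value)
  }

  def main(args: Array[String]): Unit = {
    val event = parse("""{"Stage ID": 3}""")
    val stageId = (event \ "Stage ID").extract[Int]
    // "Stage Attempt ID" was added later, so old logs simply omit it; default to 0.
    val stageAttemptId = jsonOption(event \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0)
    println((stageId, stageAttemptId)) // (3, 0)
  }
}
```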
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala
index d4e0ad93b966..b1217980faf1 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala
@@ -24,4 +24,8 @@ private[spark] case class ThreadStackTrace(
threadId: Long,
threadName: String,
threadState: Thread.State,
- stackTrace: String)
+ stackTrace: String,
+ blockedByThreadId: Option[Long],
+ blockedByLock: String,
+ holdingLocks: Seq[String])
+
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 6027b07c0fee..1de66af632a8 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,7 +18,7 @@
package org.apache.spark.util
import java.io._
-import java.lang.management.ManagementFactory
+import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo}
import java.net._
import java.nio.ByteBuffer
import java.nio.channels.Channels
@@ -2096,15 +2096,41 @@ private[spark] object Utils extends Logging {
}
}
+ private implicit class Lock(lock: LockInfo) {
+ def lockString: String = {
+ lock match {
+ case monitor: MonitorInfo =>
+ s"Monitor(${lock.getClassName}@${lock.getIdentityHashCode}})"
+ case _ =>
+ s"Lock(${lock.getClassName}@${lock.getIdentityHashCode}})"
+ }
+ }
+ }
+
/** Return a thread dump of all threads' stacktraces. Used to capture dumps for the web UI */
def getThreadDump(): Array[ThreadStackTrace] = {
// We need to filter out null values here because dumpAllThreads() may return null array
// elements for threads that are dead / don't exist.
val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null)
threadInfos.sortBy(_.getThreadId).map { case threadInfo =>
- val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n")
- ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName,
- threadInfo.getThreadState, stackTrace)
+ val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap
+ val stackTrace = threadInfo.getStackTrace.map { frame =>
+ monitors.get(frame) match {
+ case Some(monitor) =>
+ monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}"
+ case None =>
+ frame.toString
+ }
+ }.mkString("\n")
+
+ // Use a set to de-duplicate re-entrant locks that are held in multiple places.
+ val heldLocks = (threadInfo.getLockedSynchronizers.map(_.lockString)
+ ++ threadInfo.getLockedMonitors.map(_.lockString)
+ ).toSet
+
+ ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, threadInfo.getThreadState,
+ stackTrace, if (threadInfo.getLockOwnerId < 0) None else Some(threadInfo.getLockOwnerId),
+ Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""), heldLocks.toSeq)
}
}
@@ -2513,6 +2539,8 @@ private[util] object CallerContext extends Logging {
val callerContextSupported: Boolean = {
SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && {
try {
+ // `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in the
+ // master Maven build, so do not use it until SPARK-17714 is resolved.
// scalastyle:off classforname
Class.forName("org.apache.hadoop.ipc.CallerContext")
Class.forName("org.apache.hadoop.ipc.CallerContext$Builder")
@@ -2578,6 +2606,8 @@ private[spark] class CallerContext(
def setCurrentContext(): Unit = {
if (CallerContext.callerContextSupported) {
try {
+ // `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in the
+ // master Maven build, so do not use it until SPARK-17714 is resolved.
// scalastyle:off classforname
val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext")
val builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder")
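
The thread-dump change relies entirely on standard java.lang.management APIs: dumpAllThreads(true, true) reports locked monitors and ownable synchronizers per thread, getLockOwnerId identifies the blocking thread, and getLockInfo describes the lock being waited on. A standalone sketch of those calls; the output formatting here is illustrative.

```scala
import java.lang.management.ManagementFactory

object ThreadLockDumpSketch {
  def main(args: Array[String]): Unit = {
    // Request monitor and synchronizer ownership alongside the stack traces.
    val infos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null)
    infos.sortBy(_.getThreadId).foreach { info =>
      val held =
        (info.getLockedMonitors.map(m => s"Monitor(${m.getClassName}@${m.getIdentityHashCode})") ++
          info.getLockedSynchronizers.map(s => s"Lock(${s.getClassName}@${s.getIdentityHashCode})")
        ).toSet
      // getLockOwnerId is -1 when the thread is not blocked on another thread's lock.
      val blockedBy =
        if (info.getLockOwnerId < 0) "" else s" blocked by thread ${info.getLockOwnerId}"
      println(s"${info.getThreadId} ${info.getThreadName} [${info.getThreadState}]$blockedBy " +
        s"holds ${held.mkString(", ")}")
    }
  }
}
```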
diff --git a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
index 6b74a29aceda..bcb95b416dd2 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
@@ -140,16 +140,16 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64)
var i = 1
while (true) {
val curKey = data(2 * pos)
- if (k.eq(curKey) || k.equals(curKey)) {
- val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
- data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
- return newValue
- } else if (curKey.eq(null)) {
+ if (curKey.eq(null)) {
val newValue = updateFunc(false, null.asInstanceOf[V])
data(2 * pos) = k
data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
incrementSize()
return newValue
+ } else if (k.eq(curKey) || k.equals(curKey)) {
+ val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
+ data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
+ return newValue
} else {
val delta = i
pos = (pos + delta) & mask
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index 0f6a425e3db9..60f6f537c1d5 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -48,7 +48,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
require(initialCapacity <= OpenHashSet.MAX_CAPACITY,
s"Can't make capacity bigger than ${OpenHashSet.MAX_CAPACITY} elements")
- require(initialCapacity >= 1, "Invalid initial capacity")
+ require(initialCapacity >= 0, "Invalid initial capacity")
require(loadFactor < 1.0, "Load factor must be less than 1.0")
require(loadFactor > 0.0, "Load factor must be greater than 0.0")
@@ -271,8 +271,12 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
private def hashcode(h: Int): Int = Hashing.murmur3_32().hashInt(h).asInt()
private def nextPowerOf2(n: Int): Int = {
- val highBit = Integer.highestOneBit(n)
- if (highBit == n) n else highBit << 1
+ if (n == 0) {
+ 1
+ } else {
+ val highBit = Integer.highestOneBit(n)
+ if (highBit == n) n else highBit << 1
+ }
}
}
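
The OpenHashSet change allows an initial capacity of 0; the nextPowerOf2 guard matters because Integer.highestOneBit(0) is 0, so the old code would have sized the backing array to zero. A small worked check of the edge case.

```scala
object NextPowerOf2Check {
  // Same logic as the patched method above.
  def nextPowerOf2(n: Int): Int = {
    if (n == 0) {
      1
    } else {
      val highBit = Integer.highestOneBit(n)
      if (highBit == n) n else highBit << 1
    }
  }

  def main(args: Array[String]): Unit = {
    println(Seq(0, 1, 2, 3, 5, 64).map(n => n -> nextPowerOf2(n)))
    // Expected: List((0,1), (1,1), (2,2), (3,4), (5,8), (64,64))
  }
}
```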
diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala
index 5c4238c0381a..1f263df57c85 100644
--- a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala
+++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala
@@ -18,7 +18,7 @@
package org.apache.spark.util.logging
import java.text.SimpleDateFormat
-import java.util.Calendar
+import java.util.{Calendar, Locale}
import org.apache.spark.internal.Logging
@@ -59,7 +59,7 @@ private[spark] class TimeBasedRollingPolicy(
}
@volatile private var nextRolloverTime = calculateNextRolloverTime()
- private val formatter = new SimpleDateFormat(rollingFileSuffixPattern)
+ private val formatter = new SimpleDateFormat(rollingFileSuffixPattern, Locale.US)
/** Should rollover if current time has exceeded next rollover time */
def shouldRollover(bytesToBeWritten: Long): Boolean = {
@@ -109,7 +109,7 @@ private[spark] class SizeBasedRollingPolicy(
}
@volatile private var bytesWrittenSinceRollover = 0L
- val formatter = new SimpleDateFormat("--yyyy-MM-dd--HH-mm-ss--SSSS")
+ val formatter = new SimpleDateFormat("--yyyy-MM-dd--HH-mm-ss--SSSS", Locale.US)
/** Should rollover if the next set of bytes is going to exceed the size limit */
def shouldRollover(bytesToBeWritten: Long): Boolean = {
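
The SimpleDateFormat constructors above pin Locale.US so that formatting no longer depends on the JVM's default locale, which can change the names, and in some locales even the calendar or digits, that a pattern produces. A small demonstration with two explicit locales; the exact strings depend on the JDK's locale data and the default time zone.

```scala
import java.text.SimpleDateFormat
import java.util.{Date, Locale}

object LocaleSensitiveFormatExample {
  def main(args: Array[String]): Unit = {
    val instant = new Date(0L) // 1970-01-01T00:00:00Z
    val us = new SimpleDateFormat("EEE MMM dd yyyy", Locale.US)
    val fr = new SimpleDateFormat("EEE MMM dd yyyy", Locale.FRANCE)
    println(us.format(instant)) // e.g. "Thu Jan 01 1970"
    println(fr.format(instant)) // e.g. "jeu. janv. 01 1970"
    // Text produced under one locale cannot be parsed by a formatter built for another:
    try fr.parse(us.format(instant)) catch {
      case e: java.text.ParseException => println(s"cross-locale parse failed: ${e.getMessage}")
    }
  }
}
```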
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index 6724af952505..0f78871ed35a 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -44,7 +44,7 @@ abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[So
{
implicit val defaultTimeout = timeout(10000 millis)
val conf = new SparkConf()
- .setMaster("local[2]")
+ .setMaster("local[4]")
.setAppName("ContextCleanerSuite")
.set("spark.cleaner.referenceTracking.blocking", "true")
.set("spark.cleaner.referenceTracking.blocking.shuffle", "true")
@@ -232,7 +232,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase {
// Verify that checkpoints are NOT cleaned up if the config is not enabled
sc.stop()
val conf = new SparkConf()
- .setMaster("local[2]")
+ .setMaster("local[4]")
.setAppName("cleanupCheckpoint")
.set("spark.cleaner.referenceTracking.cleanCheckpoints", "false")
sc = new SparkContext(conf)
diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala
index cc52bb1d23cd..89f0b1cb5b56 100644
--- a/core/src/test/scala/org/apache/spark/FileSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -58,10 +58,15 @@ class FileSuite extends SparkFunSuite with LocalSparkContext {
nums.saveAsTextFile(outputDir)
// Read the plain text file and check it's OK
val outputFile = new File(outputDir, "part-00000")
- val content = Source.fromFile(outputFile).mkString
- assert(content === "1\n2\n3\n4\n")
- // Also try reading it in as a text file RDD
- assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4"))
+ val bufferSrc = Source.fromFile(outputFile)
+ Utils.tryWithSafeFinally {
+ val content = bufferSrc.mkString
+ assert(content === "1\n2\n3\n4\n")
+ // Also try reading it in as a text file RDD
+ assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4"))
+ } {
+ bufferSrc.close()
+ }
}
test("text files (compressed)") {
diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
index 915d7a1b8b16..5457a066d3c0 100644
--- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
@@ -67,7 +67,7 @@ class HeartbeatReceiverSuite
override def beforeEach(): Unit = {
super.beforeEach()
val conf = new SparkConf()
- .setMaster("local[2]")
+ .setMaster("local[4]")
.setAppName("test")
.set("spark.dynamicAllocation.testing", "true")
sc = spy(new SparkContext(conf))
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index a3490fc79e45..5b89eaae032a 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -47,7 +47,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
test("local mode, FIFO scheduler") {
val conf = new SparkConf().set("spark.scheduler.mode", "FIFO")
- sc = new SparkContext("local[2]", "test", conf)
+ sc = new SparkContext("local[4]", "test", conf)
testCount()
testTake()
// Make sure we can still launch tasks.
@@ -58,7 +58,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
val conf = new SparkConf().set("spark.scheduler.mode", "FAIR")
val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
conf.set("spark.scheduler.allocation.file", xmlPath)
- sc = new SparkContext("local[2]", "test", conf)
+ sc = new SparkContext("local[4]", "test", conf)
testCount()
testTake()
// Make sure we can still launch tasks.
@@ -115,7 +115,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
}
test("job group") {
- sc = new SparkContext("local[2]", "test")
+ sc = new SparkContext("local[4]", "test")
// Add a listener to release the semaphore once any tasks are launched.
val sem = new Semaphore(0)
@@ -145,7 +145,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
}
test("inherited job group (SPARK-6629)") {
- sc = new SparkContext("local[2]", "test")
+ sc = new SparkContext("local[4]", "test")
// Add a listener to release the semaphore once any tasks are launched.
val sem = new Semaphore(0)
@@ -180,7 +180,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
}
test("job group with interruption") {
- sc = new SparkContext("local[2]", "test")
+ sc = new SparkContext("local[4]", "test")
// Add a listener to release the semaphore once any tasks are launched.
val sem = new Semaphore(0)
@@ -215,7 +215,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
// make sure the first stage is not finished until cancel is issued
val sem1 = new Semaphore(0)
- sc = new SparkContext("local[2]", "test")
+ sc = new SparkContext("local[4]", "test")
sc.addSparkListener(new SparkListener {
override def onTaskStart(taskStart: SparkListenerTaskStart) {
sem1.release()
diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
index 83906cff123b..21b2726d7e1d 100644
--- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
@@ -132,8 +132,8 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst
test("SparkContext property overriding") {
val conf = new SparkConf(false).setMaster("local").setAppName("My app")
- sc = new SparkContext("local[2]", "My other app", conf)
- assert(sc.master === "local[2]")
+ sc = new SparkContext("local[4]", "My other app", conf)
+ assert(sc.master === "local[4]")
assert(sc.appName === "My other app")
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala
index 13cba94578a6..005587051b6a 100644
--- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala
@@ -33,7 +33,7 @@ import org.scalatest.BeforeAndAfterEach
import org.apache.spark.SparkFunSuite
import org.apache.spark.api.r.RUtils
import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate
-import org.apache.spark.util.ResetSystemProperties
+import org.apache.spark.util.{ResetSystemProperties, Utils}
class RPackageUtilsSuite
extends SparkFunSuite
@@ -74,9 +74,13 @@ class RPackageUtilsSuite
val deps = Seq(dep1, dep2).mkString(",")
IvyTestUtils.withRepository(main, Some(deps), None, withR = true) { repo =>
val jars = Seq(main, dep1, dep2).map(c => new JarFile(getJarPath(c, new File(new URI(repo)))))
- assert(RPackageUtils.checkManifestForR(jars(0)), "should have R code")
- assert(!RPackageUtils.checkManifestForR(jars(1)), "should not have R code")
- assert(!RPackageUtils.checkManifestForR(jars(2)), "should not have R code")
+ Utils.tryWithSafeFinally {
+ assert(RPackageUtils.checkManifestForR(jars(0)), "should have R code")
+ assert(!RPackageUtils.checkManifestForR(jars(1)), "should not have R code")
+ assert(!RPackageUtils.checkManifestForR(jars(2)), "should not have R code")
+ } {
+ jars.foreach(_.close())
+ }
}
}
@@ -131,7 +135,7 @@ class RPackageUtilsSuite
test("SparkR zipping works properly") {
val tempDir = Files.createTempDir()
- try {
+ Utils.tryWithSafeFinally {
IvyTestUtils.writeFile(tempDir, "test.R", "abc")
val fakeSparkRDir = new File(tempDir, "SparkR")
assert(fakeSparkRDir.mkdirs())
@@ -144,14 +148,19 @@ class RPackageUtilsSuite
IvyTestUtils.writeFile(fakePackageDir, "DESCRIPTION", "abc")
val finalZip = RPackageUtils.zipRLibraries(tempDir, "sparkr.zip")
assert(finalZip.exists())
- val entries = new ZipFile(finalZip).entries().asScala.map(_.getName).toSeq
- assert(entries.contains("/test.R"))
- assert(entries.contains("/SparkR/abc.R"))
- assert(entries.contains("/SparkR/DESCRIPTION"))
- assert(!entries.contains("/package.zip"))
- assert(entries.contains("/packageTest/def.R"))
- assert(entries.contains("/packageTest/DESCRIPTION"))
- } finally {
+ val zipFile = new ZipFile(finalZip)
+ Utils.tryWithSafeFinally {
+ val entries = zipFile.entries().asScala.map(_.getName).toSeq
+ assert(entries.contains("/test.R"))
+ assert(entries.contains("/SparkR/abc.R"))
+ assert(entries.contains("/SparkR/DESCRIPTION"))
+ assert(!entries.contains("/package.zip"))
+ assert(entries.contains("/packageTest/def.R"))
+ assert(entries.contains("/packageTest/DESCRIPTION"))
+ } {
+ zipFile.close()
+ }
+ } {
FileUtils.deleteDirectory(tempDir)
}
}
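
The RPackageUtilsSuite changes above swap plain try/finally for Utils.tryWithSafeFinally so that a close() failure cannot mask the original assertion error. A minimal sketch of the idiom, assuming the caller lives under an org.apache.spark package (Utils is spark-internal) and a hypothetical zip path:

    import java.util.zip.ZipFile
    import org.apache.spark.util.Utils

    // Sketch only: read one entry name and always close the ZipFile.
    // If both the body and close() throw, the body's exception is rethrown and
    // the close() failure is attached as a suppressed exception.
    def firstEntryName(path: String): String = {
      val zip = new ZipFile(path)
      Utils.tryWithSafeFinally {
        zip.entries().nextElement().getName
      } {
        zip.close()
      }
    }
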
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
index a5eda7b5a5a7..2c41c432d1fe 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
@@ -449,8 +449,14 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
val cstream = codec.map(_.compressedOutputStream(fstream)).getOrElse(fstream)
val bstream = new BufferedOutputStream(cstream)
if (isNewFormat) {
- EventLoggingListener.initEventLog(new FileOutputStream(file))
+ val newFormatStream = new FileOutputStream(file)
+ Utils.tryWithSafeFinally {
+ EventLoggingListener.initEventLog(newFormatStream)
+ } {
+ newFormatStream.close()
+ }
}
+
val writer = new OutputStreamWriter(bstream, StandardCharsets.UTF_8)
Utils.tryWithSafeFinally {
events.foreach(e => writer.write(compact(render(JsonProtocol.sparkEventToJson(e))) + "\n"))
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
index a595bc174a31..715811a46f42 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
@@ -29,6 +29,8 @@ import com.codahale.metrics.Counter
import com.google.common.io.{ByteStreams, Files}
import org.apache.commons.io.{FileUtils, IOUtils}
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+import org.eclipse.jetty.proxy.ProxyServlet
+import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}
import org.json4s.JsonAST._
import org.json4s.jackson.JsonMethods
import org.json4s.jackson.JsonMethods._
@@ -258,8 +260,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
getContentAndCode("foobar")._1 should be (HttpServletResponse.SC_NOT_FOUND)
}
- test("relative links are prefixed with uiRoot (spark.ui.proxyBase)") {
- val proxyBaseBeforeTest = System.getProperty("spark.ui.proxyBase")
+ test("static relative links are prefixed with uiRoot (spark.ui.proxyBase)") {
val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase")
val page = new HistoryPage(server)
val request = mock[HttpServletRequest]
@@ -267,7 +268,6 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
// when
System.setProperty("spark.ui.proxyBase", uiRoot)
val response = page.render(request)
- System.setProperty("spark.ui.proxyBase", Option(proxyBaseBeforeTest).getOrElse(""))
// then
val urls = response \\ "@href" map (_.toString)
@@ -275,6 +275,80 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
all (siteRelativeLinks) should startWith (uiRoot)
}
+ test("ajax rendered relative links are prefixed with uiRoot (spark.ui.proxyBase)") {
+ val uiRoot = "/testwebproxybase"
+ System.setProperty("spark.ui.proxyBase", uiRoot)
+
+ server.stop()
+
+ val conf = new SparkConf()
+ .set("spark.history.fs.logDirectory", logDir)
+ .set("spark.history.fs.update.interval", "0")
+ .set("spark.testing", "true")
+
+ provider = new FsHistoryProvider(conf)
+ provider.checkForLogs()
+ val securityManager = new SecurityManager(conf)
+
+ server = new HistoryServer(conf, provider, securityManager, 18080)
+ server.initialize()
+ server.bind()
+
+ val port = server.boundPort
+
+ val servlet = new ProxyServlet {
+ override def rewriteTarget(request: HttpServletRequest): String = {
+ // servlet acts like a proxy that redirects calls made on
+ // the spark.ui.proxyBase context path to the normal servlet handlers operating off "/"
+ val sb = request.getRequestURL()
+
+ if (request.getQueryString() != null) {
+ sb.append(s"?${request.getQueryString()}")
+ }
+
+ val proxyidx = sb.indexOf(uiRoot)
+ sb.delete(proxyidx, proxyidx + uiRoot.length).toString
+ }
+ }
+
+ val contextHandler = new ServletContextHandler
+ val holder = new ServletHolder(servlet)
+ contextHandler.setContextPath(uiRoot)
+ contextHandler.addServlet(holder, "/")
+ server.attachHandler(contextHandler)
+
+ implicit val webDriver: WebDriver = new HtmlUnitDriver(true) {
+ getWebClient.getOptions.setThrowExceptionOnScriptError(false)
+ }
+
+ try {
+ val url = s"http://localhost:$port"
+
+ go to s"$url$uiRoot"
+
+ // expect the ajax call to finish in 5 seconds
+ implicitlyWait(org.scalatest.time.Span(5, org.scalatest.time.Seconds))
+
+ // once this findAll call returns, we know the ajax load of the table completed
+ findAll(ClassNameQuery("odd"))
+
+ val links = findAll(TagNameQuery("a"))
+ .map(_.attribute("href"))
+ .filter(_.isDefined)
+ .map(_.get)
+ .filter(_.startsWith(url)).toList
+
+ // there are at least some URL links that were generated via JavaScript,
+ // and they all contain the spark.ui.proxyBase (uiRoot)
+ links.length should be > 4
+ all(links) should startWith(url + uiRoot)
+ } finally {
+ contextHandler.stop()
+ quit()
+ }
+
+ }
+
test("incomplete apps get refreshed") {
implicit val webDriver: WebDriver = new HtmlUnitDriver
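
For reference, this is the shape of the proxy that the new ajax test installs, pulled out on its own: a Jetty ProxyServlet mounted at the proxy base whose rewriteTarget strips that prefix and forwards to the same server. A sketch only; the proxy base value and the server it would be attached to are assumptions borrowed from the test above.

    import javax.servlet.http.HttpServletRequest
    import org.eclipse.jetty.proxy.ProxyServlet
    import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}

    val proxyBase = "/testwebproxybase"  // assumed value of spark.ui.proxyBase

    // Rewrites http://host:port/testwebproxybase/... to http://host:port/...
    val prefixStrippingProxy = new ProxyServlet {
      override def rewriteTarget(request: HttpServletRequest): String = {
        val url = new StringBuilder(request.getRequestURL.toString)
        if (request.getQueryString != null) {
          url.append("?").append(request.getQueryString)
        }
        val idx = url.indexOf(proxyBase)
        url.delete(idx, idx + proxyBase.length).toString
      }
    }

    val contextHandler = new ServletContextHandler
    contextHandler.setContextPath(proxyBase)
    contextHandler.addServlet(new ServletHolder(prefixStrippingProxy), "/")
    // contextHandler would then be attached to the HistoryServer, as in the test.
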
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index 58664e77d24a..ef5845a77c11 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -36,7 +36,7 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim
override def beforeAll() {
super.beforeAll()
- sc = new SparkContext("local[2]", "test")
+ sc = new SparkContext("local[4]", "test")
}
override def afterAll() {
diff --git a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala
index 2802cd975292..5ff61b35c8bc 100644
--- a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala
@@ -28,7 +28,7 @@ class LocalCheckpointSuite extends SparkFunSuite with LocalSparkContext {
override def beforeEach(): Unit = {
super.beforeEach()
- sc = new SparkContext("local[2]", "test")
+ sc = new SparkContext("local[4]", "test")
}
test("transform storage level") {
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index b0d69de6e2ef..02df157be377 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -516,10 +516,10 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
pairs.saveAsNewAPIHadoopFile[NewFakeFormat]("ignored")
/*
- Check that configurable formats get configured:
- ConfigTestFormat throws an exception if we try to write
- to it when setConf hasn't been called first.
- Assertion is in ConfigTestFormat.getRecordWriter.
+ * Check that configurable formats get configured:
+ * ConfigTestFormat throws an exception if we try to write
+ * to it when setConf hasn't been called first.
+ * Assertion is in ConfigTestFormat.getRecordWriter.
*/
pairs.saveAsNewAPIHadoopFile[ConfigTestFormat]("ignored")
}
@@ -544,7 +544,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
val e = intercept[SparkException] {
pairs.saveAsNewAPIHadoopFile[NewFakeFormatWithCallback]("ignored")
}
- assert(e.getMessage contains "failed to write")
+ assert(e.getCause.getMessage contains "failed to write")
assert(FakeWriterWithCallback.calledBy === "write,callback,close")
assert(FakeWriterWithCallback.exception != null, "exception should be captured")
@@ -725,8 +725,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
}
/*
- These classes are fakes for testing
- "saveNewAPIHadoopFile should call setConf if format is configurable".
+ These classes are fakes for testing saveAsHadoopFile/saveNewAPIHadoopFile.
Unfortunately, they have to be top level classes, and not defined in
the test method, because otherwise Scala won't generate no-args constructors
and the test will therefore throw InstantiationException when saveAsNewAPIHadoopFile
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index acdf21df9a16..aa0705987d83 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -870,6 +870,19 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
verify(endpoint, never()).onDisconnected(any())
verify(endpoint, never()).onNetworkError(any(), any())
}
+
+ test("isInRPCThread") {
+ val rpcEndpointRef = env.setupEndpoint("isInRPCThread", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case m => context.reply(rpcEnv.isInRPCThread)
+ }
+ })
+ assert(rpcEndpointRef.askWithRetry[Boolean]("hello") === true)
+ assert(env.isInRPCThread === false)
+ env.stop(rpcEndpointRef)
+ }
}
class UnserializableClass
diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
index 7f4859206e25..8a5ec37eeb66 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
@@ -202,8 +202,6 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit
// Make sure expected events exist in the log file.
val logData = EventLoggingListener.openEventLog(new Path(eventLogger.logPath), fileSystem)
- val logStart = SparkListenerLogStart(SPARK_VERSION)
- val lines = readLines(logData)
val eventSet = mutable.Set(
SparkListenerApplicationStart,
SparkListenerBlockManagerAdded,
@@ -216,19 +214,25 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit
SparkListenerTaskStart,
SparkListenerTaskEnd,
SparkListenerApplicationEnd).map(Utils.getFormattedClassName)
- lines.foreach { line =>
- eventSet.foreach { event =>
- if (line.contains(event)) {
- val parsedEvent = JsonProtocol.sparkEventFromJson(parse(line))
- val eventType = Utils.getFormattedClassName(parsedEvent)
- if (eventType == event) {
- eventSet.remove(event)
+ Utils.tryWithSafeFinally {
+ val logStart = SparkListenerLogStart(SPARK_VERSION)
+ val lines = readLines(logData)
+ lines.foreach { line =>
+ eventSet.foreach { event =>
+ if (line.contains(event)) {
+ val parsedEvent = JsonProtocol.sparkEventFromJson(parse(line))
+ val eventType = Utils.getFormattedClassName(parsedEvent)
+ if (eventType == event) {
+ eventSet.remove(event)
+ }
}
}
}
+ assert(JsonProtocol.sparkEventFromJson(parse(lines(0))) === logStart)
+ assert(eventSet.isEmpty, "The following events are missing: " + eventSet.toSeq)
+ } {
+ logData.close()
}
- assert(JsonProtocol.sparkEventFromJson(parse(lines(0))) === logStart)
- assert(eventSet.isEmpty, "The following events are missing: " + eventSet.toSeq)
}
private def readLines(in: InputStream): Seq[String] = {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index 9e472f900b65..ee95e4ff7dbc 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -183,9 +183,9 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local
// ensure we reset the classloader after the test completes
val originalClassLoader = Thread.currentThread.getContextClassLoader
- try {
+ val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader)
+ Utils.tryWithSafeFinally {
// load the exception from the jar
- val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader)
loader.addURL(jarFile.toURI.toURL)
Thread.currentThread().setContextClassLoader(loader)
val excClass: Class[_] = Utils.classForName("repro.MyException")
@@ -209,8 +209,9 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local
assert(expectedFailure.findFirstMatchIn(exceptionMessage).isDefined)
assert(unknownFailure.findFirstMatchIn(exceptionMessage).isEmpty)
- } finally {
+ } {
Thread.currentThread.setContextClassLoader(originalClassLoader)
+ loader.close()
}
}
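
The TaskResultGetterSuite change follows the same resource-safety theme: the MutableURLClassLoader is created outside the guarded block so it can be closed, and the original context class loader is always restored. A generic sketch of that pattern, assuming spark-internal code (MutableURLClassLoader and Utils are not public API) and a caller-supplied jar URL:

    import java.net.URL
    import org.apache.spark.util.{MutableURLClassLoader, Utils}

    // Run `body` with an extra jar on the context class loader, then restore
    // the previous loader and close the temporary one even if `body` throws.
    def withExtraJar[T](jarUrl: URL)(body: => T): T = {
      val original = Thread.currentThread.getContextClassLoader
      val loader = new MutableURLClassLoader(Array(jarUrl), original)
      Utils.tryWithSafeFinally {
        Thread.currentThread().setContextClassLoader(loader)
        body
      } {
        Thread.currentThread().setContextClassLoader(original)
        loader.close()
      }
    }
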
diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
index e5d408a16736..f4786e3931c9 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
@@ -473,7 +473,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync()
eventually(timeout(5 seconds), interval(50 milliseconds)) {
val url = new URL(
- sc.ui.get.appUIAddress.stripSuffix("/") + "/stages/stage/kill/?id=0")
+ sc.ui.get.webUrl.stripSuffix("/") + "/stages/stage/kill/?id=0")
// SPARK-6846: should be POST only but YARN AM doesn't proxy POST
getResponseCode(url, "GET") should be (200)
getResponseCode(url, "POST") should be (200)
@@ -486,7 +486,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync()
eventually(timeout(5 seconds), interval(50 milliseconds)) {
val url = new URL(
- sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs/job/kill/?id=0")
+ sc.ui.get.webUrl.stripSuffix("/") + "/jobs/job/kill/?id=0")
// SPARK-6846: should be POST only but YARN AM doesn't proxy POST
getResponseCode(url, "GET") should be (200)
getResponseCode(url, "POST") should be (200)
@@ -620,7 +620,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
test("live UI json application list") {
withSpark(newSparkContext()) { sc =>
val appListRawJson = HistoryServerSuite.getUrl(new URL(
- sc.ui.get.appUIAddress + "/api/v1/applications"))
+ sc.ui.get.webUrl + "/api/v1/applications"))
val appListJsonAst = JsonMethods.parse(appListRawJson)
appListJsonAst.children.length should be (1)
val attempts = (appListJsonAst \ "attempts").children
@@ -640,7 +640,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity).map(identity).groupBy(identity)
rdd.count()
- val stage0 = Source.fromURL(sc.ui.get.appUIAddress +
+ val stage0 = Source.fromURL(sc.ui.get.webUrl +
"/stages/stage/?id=0&attempt=0&expandDagViz=true").mkString
assert(stage0.contains("digraph G {\n subgraph clusterstage_0 {\n " +
"label="Stage 0";\n subgraph "))
@@ -651,7 +651,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
assert(stage0.contains("{\n label="groupBy";\n " +
"2 [label="MapPartitionsRDD [2]"))
- val stage1 = Source.fromURL(sc.ui.get.appUIAddress +
+ val stage1 = Source.fromURL(sc.ui.get.webUrl +
"/stages/stage/?id=1&attempt=0&expandDagViz=true").mkString
assert(stage1.contains("digraph G {\n subgraph clusterstage_1 {\n " +
"label="Stage 1";\n subgraph "))
@@ -662,7 +662,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
assert(stage1.contains("{\n label="groupBy";\n " +
"5 [label="MapPartitionsRDD [5]"))
- val stage2 = Source.fromURL(sc.ui.get.appUIAddress +
+ val stage2 = Source.fromURL(sc.ui.get.webUrl +
"/stages/stage/?id=2&attempt=0&expandDagViz=true").mkString
assert(stage2.contains("digraph G {\n subgraph clusterstage_2 {\n " +
"label="Stage 2";\n subgraph "))
@@ -687,7 +687,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
}
def goToUi(ui: SparkUI, path: String): Unit = {
- go to (ui.appUIAddress.stripSuffix("/") + path)
+ go to (ui.webUrl.stripSuffix("/") + path)
}
def parseDate(json: JValue): Long = {
@@ -699,6 +699,6 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
}
def apiUrl(ui: SparkUI, path: String): URL = {
- new URL(ui.appUIAddress + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path)
+ new URL(ui.webUrl + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path)
}
}
diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
index 4abcfb7e5191..68c7657cb315 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
@@ -66,7 +66,7 @@ class UISuite extends SparkFunSuite {
withSpark(newSparkContext()) { sc =>
// test if the ui is visible, and all the expected tabs are visible
eventually(timeout(10 seconds), interval(50 milliseconds)) {
- val html = Source.fromURL(sc.ui.get.appUIAddress).mkString
+ val html = Source.fromURL(sc.ui.get.webUrl).mkString
assert(!html.contains("random data that should not be present"))
assert(html.toLowerCase.contains("stages"))
assert(html.toLowerCase.contains("storage"))
@@ -176,19 +176,18 @@ class UISuite extends SparkFunSuite {
}
}
- test("verify appUIAddress contains the scheme") {
+ test("verify webUrl contains the scheme") {
withSpark(newSparkContext()) { sc =>
val ui = sc.ui.get
- val uiAddress = ui.appUIAddress
- val uiHostPort = ui.appUIHostPort
- assert(uiAddress.equals("http://" + uiHostPort))
+ val uiAddress = ui.webUrl
+ assert(uiAddress.startsWith("http://") || uiAddress.startsWith("https://"))
}
}
- test("verify appUIAddress contains the port") {
+ test("verify webUrl contains the port") {
withSpark(newSparkContext()) { sc =>
val ui = sc.ui.get
- val splitUIAddress = ui.appUIAddress.split(':')
+ val splitUIAddress = ui.webUrl.split(':')
val boundPort = ui.boundPort
assert(splitUIAddress(2).toInt == boundPort)
}
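
Since appUIAddress and appUIHostPort are gone, callers read the UI location from webUrl, which already carries the scheme and bound port. A small sketch of the replacement usage; withSpark and newSparkContext are the suite helpers used above:

    withSpark(newSparkContext()) { sc =>
      val ui = sc.ui.get
      val url = ui.webUrl                       // e.g. "http://host:4040"
      assert(url.startsWith("http://") || url.startsWith("https://"))
      assert(url.endsWith(":" + ui.boundPort))  // port is part of webUrl
    }
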
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 8418fa74d2c6..da853f1be8b9 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -403,7 +403,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
internal = false,
countFailedValues = false,
metadata = None)
- taskInfo.accumulables ++= Seq(internalAccum, sqlAccum, userAccum)
+ taskInfo.setAccumulables(List(internalAccum, sqlAccum, userAccum))
val newTaskInfo = TaskUIData.dropInternalAndSQLAccumulables(taskInfo)
assert(newTaskInfo.accumulables === Seq(userAccum))
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index d5146d70ebaa..85da79180fd0 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -788,11 +788,8 @@ private[spark] object JsonProtocolSuite extends Assertions {
private def makeTaskInfo(a: Long, b: Int, c: Int, d: Long, speculative: Boolean) = {
val taskInfo = new TaskInfo(a, b, c, d, "executor", "your kind sir", TaskLocality.NODE_LOCAL,
speculative)
- val (acc1, acc2, acc3) =
- (makeAccumulableInfo(1), makeAccumulableInfo(2), makeAccumulableInfo(3, internal = true))
- taskInfo.accumulables += acc1
- taskInfo.accumulables += acc2
- taskInfo.accumulables += acc3
+ taskInfo.setAccumulables(
+ List(makeAccumulableInfo(1), makeAccumulableInfo(2), makeAccumulableInfo(3, internal = true)))
taskInfo
}
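
Both listener suites reflect the same API change: TaskInfo.accumulables is no longer a buffer to append to; the full list is handed over once via setAccumulables. A sketch of the new call, assuming spark-internal test code (the TaskInfo constructor arguments here are placeholders and setAccumulables is not public API):

    import org.apache.spark.scheduler.{TaskInfo, TaskLocality}

    val info = new TaskInfo(
      0L, 0, 0, 0L, "exec-1", "host-1", TaskLocality.PROCESS_LOCAL, false)
    // Replaces any previous contents in a single call instead of += / ++=.
    info.setAccumulables(Nil)
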
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 15ef32f21d90..feacfb7642f2 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -264,7 +264,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
val hour = minute * 60
def str: (Long) => String = Utils.msDurationToString(_)
- val sep = new DecimalFormatSymbols(Locale.getDefault()).getDecimalSeparator()
+ val sep = new DecimalFormatSymbols(Locale.US).getDecimalSeparator
assert(str(123) === "123 ms")
assert(str(second) === "1" + sep + "0 s")
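
The switch from Locale.getDefault() to Locale.US matters because the decimal separator is locale-dependent, so the expected string in this test would otherwise change with the JVM's default locale. A quick illustration of the difference, using only the JDK:

    import java.text.DecimalFormatSymbols
    import java.util.Locale

    // '.' in the US locale, ',' in e.g. the French locale.
    val us = new DecimalFormatSymbols(Locale.US).getDecimalSeparator
    val fr = new DecimalFormatSymbols(Locale.FRANCE).getDecimalSeparator
    assert(us == '.' && fr == ',')
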
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
index 3066e9996abd..335ecb9320ab 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
@@ -49,9 +49,6 @@ class OpenHashMapSuite extends SparkFunSuite with Matchers {
intercept[IllegalArgumentException] {
new OpenHashMap[String, Int](-1)
}
- intercept[IllegalArgumentException] {
- new OpenHashMap[String, String](0)
- }
}
test("primitive value") {
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
index 2607a543dd61..210bc5c09974 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
@@ -176,4 +176,9 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers {
assert(set.size === 1000)
assert(set.capacity > 1000)
}
+
+ test("SPARK-18200 Support zero as an initial set size") {
+ val set = new OpenHashSet[Long](0)
+ assert(set.size === 0)
+ }
}
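
The removed intercept checks in OpenHashMapSuite and PrimitiveKeyOpenHashMapSuite and the new OpenHashSet test are two sides of the same SPARK-18200 change: an initial size of zero is now accepted. A sketch, assuming spark-internal code (these collections are private[spark]) and assuming the relaxed check also covers the map wrappers, as the deleted intercepts suggest:

    import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet}

    // Zero is now a legal initial size; it previously threw IllegalArgumentException.
    val set = new OpenHashSet[Long](0)
    val map = new OpenHashMap[String, Int](0)
    assert(set.size == 0 && map.size == 0)
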
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
index 508e737b725b..f5ee428020fd 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
@@ -49,9 +49,6 @@ class PrimitiveKeyOpenHashMapSuite extends SparkFunSuite with Matchers {
intercept[IllegalArgumentException] {
new PrimitiveKeyOpenHashMap[Int, Int](-1)
}
- intercept[IllegalArgumentException] {
- new PrimitiveKeyOpenHashMap[Int, Int](0)
- }
}
test("basic operations") {
diff --git a/docs/building-spark.md b/docs/building-spark.md
index ebe46a42a15c..2b404bd3e116 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -13,6 +13,7 @@ redirect_from: "building-with-maven.html"
The Maven-based build is the build of reference for Apache Spark.
Building Spark using Maven requires Maven 3.3.9 or newer and Java 7+.
+Note that support for Java 7 is deprecated as of Spark 2.0.0 and may be removed in Spark 2.2.0.
### Setting up Maven's Memory Usage
@@ -79,6 +80,9 @@ Because HDFS is not protocol-compatible across versions, if you want to read from
[The remaining docs hunks did not survive extraction intact. The recoverable pieces are additions to the configuration tables in docs/configuration.md: a note that spark.r.shell.command is used for the sparkR shell while spark.r.driver.command is used for running an R script, plus new entries for spark.r.backendConnectionTimeout, spark.r.heartBeatInterval, and spark.mesos.fetcherCache.enable (default false), listed alongside existing properties such as spark.kryo.referenceTracking, spark.serializer, spark.files.maxPartitionBytes, spark.files.openCostInBytes, and spark.hadoop.cloneConf.]