diff --git a/LICENSE b/LICENSE index c21032a1fd27..66a2e8f13295 100644 --- a/LICENSE +++ b/LICENSE @@ -249,11 +249,11 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (Interpreter classes (all .scala files in repl/src/main/scala except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), and for SerializableMapWrapper in JavaUtils.scala) - (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.7 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.7 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.7 - http://www.scala-lang.org/) - (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.7 - http://www.scala-lang.org/) - (BSD-like) Scalap (org.scala-lang:scalap:2.11.7 - http://www.scala-lang.org/) + (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.8 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.8 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.8 - http://www.scala-lang.org/) + (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.8 - http://www.scala-lang.org/) + (BSD-like) Scalap (org.scala-lang:scalap:2.11.8 - http://www.scala-lang.org/) (BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org) (BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - http://spire-math.org) (BSD-style) spire-macros (org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index ba0fe7708bcc..4e3fe00a2e9b 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -63,6 +63,7 @@ exportMethods("glm", "spark.als", "spark.kstest", "spark.logit", + "spark.decisionTree", "spark.randomForest", "spark.gbt", "spark.bisectingKmeans", @@ -84,6 +85,7 @@ exportClasses("SparkDataFrame") exportMethods("arrange", "as.data.frame", "attach", + "broadcast", "cache", "checkpoint", "coalesce", @@ -413,6 +415,8 @@ export("as.DataFrame", "print.summary.GeneralizedLinearRegressionModel", "read.ml", "print.summary.KSTest", + "print.summary.DecisionTreeRegressionModel", + "print.summary.DecisionTreeClassificationModel", "print.summary.RandomForestRegressionModel", "print.summary.RandomForestClassificationModel", "print.summary.GBTRegressionModel", @@ -451,6 +455,8 @@ S3method(print, structField) S3method(print, structType) S3method(print, summary.GeneralizedLinearRegressionModel) S3method(print, summary.KSTest) +S3method(print, summary.DecisionTreeRegressionModel) +S3method(print, summary.DecisionTreeClassificationModel) S3method(print, summary.RandomForestRegressionModel) S3method(print, summary.RandomForestClassificationModel) S3method(print, summary.GBTRegressionModel) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 1c8869202f67..166b39813c14 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -549,7 +549,7 @@ setMethod("registerTempTable", #' sparkR.session() #' df <- read.df(path, "parquet") #' df2 <- read.df(path2, "parquet") -#' createOrReplaceTempView(df, "table1") +#' saveAsTable(df, "table1") #' insertInto(df2, "table1", overwrite = TRUE) #'} #' @note insertInto since 1.4.0 @@ -1125,7 +1125,8 @@ setMethod("dim", #' path <- "path/to/file.json" #' df <- read.json(path) #' collected <- collect(df) -#' firstName <- collected[[1]]$name +#' class(collected) +#' firstName <- names(collected)[1] #' } #' @note collect since 1.4.0 setMethod("collect", @@ -2814,7 +2815,7 @@ 
setMethod("except", #' path <- "path/to/file.json" #' df <- read.json(path) #' write.df(df, "myfile", "parquet", "overwrite") -#' saveDF(df, parquetPath2, "parquet", mode = saveMode, mergeSchema = mergeSchema) +#' saveDF(df, parquetPath2, "parquet", mode = "append", mergeSchema = TRUE) #' } #' @note write.df since 1.4.0 setMethod("write.df", @@ -3097,8 +3098,8 @@ setMethod("fillna", #' @family SparkDataFrame functions #' @aliases as.data.frame,SparkDataFrame-method #' @rdname as.data.frame -#' @examples \dontrun{ -#' +#' @examples +#' \dontrun{ #' irisDF <- createDataFrame(iris) #' df <- as.data.frame(irisDF[irisDF$Species == "setosa", ]) #' } @@ -3175,7 +3176,8 @@ setMethod("with", #' @aliases str,SparkDataFrame-method #' @family SparkDataFrame functions #' @param object a SparkDataFrame -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' # Create a SparkDataFrame from the Iris dataset #' irisDF <- createDataFrame(iris) #' @@ -3667,8 +3669,8 @@ setMethod("checkpoint", #' mean(cube(df, "cyl", "gear", "am"), "mpg") #' #' # Following calls are equivalent -#' agg(cube(carsDF), mean(carsDF$mpg)) -#' agg(carsDF, mean(carsDF$mpg)) +#' agg(cube(df), mean(df$mpg)) +#' agg(df, mean(df$mpg)) #' } #' @note cube since 2.3.0 #' @seealso \link{agg}, \link{groupBy}, \link{rollup} @@ -3702,8 +3704,8 @@ setMethod("cube", #' mean(rollup(df, "cyl", "gear", "am"), "mpg") #' #' # Following calls are equivalent -#' agg(rollup(carsDF), mean(carsDF$mpg)) -#' agg(carsDF, mean(carsDF$mpg)) +#' agg(rollup(df), mean(df$mpg)) +#' agg(df, mean(df$mpg)) #' } #' @note rollup since 2.3.0 #' @seealso \link{agg}, \link{cube}, \link{groupBy} @@ -3745,3 +3747,56 @@ setMethod("hint", jdf <- callJMethod(x@sdf, "hint", name, parameters) dataFrame(jdf) }) + +#' alias +#' +#' @aliases alias,SparkDataFrame-method +#' @family SparkDataFrame functions +#' @rdname alias +#' @name alias +#' @export +#' @examples +#' \dontrun{ +#' df <- alias(createDataFrame(mtcars), "mtcars") +#' avg_mpg <- alias(agg(groupBy(df, df$cyl), avg(df$mpg)), "avg_mpg") +#' +#' head(select(df, column("mtcars.mpg"))) +#' head(join(df, avg_mpg, column("mtcars.cyl") == column("avg_mpg.cyl"))) +#' } +#' @note alias(SparkDataFrame) since 2.3.0 +setMethod("alias", + signature(object = "SparkDataFrame"), + function(object, data) { + stopifnot(is.character(data)) + sdf <- callJMethod(object@sdf, "alias", data) + dataFrame(sdf) + }) + +#' broadcast +#' +#' Return a new SparkDataFrame marked as small enough for use in broadcast joins. +#' +#' Equivalent to \code{hint(x, "broadcast")}. +#' +#' @param x a SparkDataFrame. +#' @return a SparkDataFrame. 
+#' +#' @aliases broadcast,SparkDataFrame-method +#' @family SparkDataFrame functions +#' @rdname broadcast +#' @name broadcast +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(mtcars) +#' avg_mpg <- mean(groupBy(createDataFrame(mtcars), "cyl"), "mpg") +#' +#' head(join(df, broadcast(avg_mpg), df$cyl == avg_mpg$cyl)) +#' } +#' @note broadcast since 2.3.0 +setMethod("broadcast", + signature(x = "SparkDataFrame"), + function(x) { + sdf <- callJStatic("org.apache.spark.sql.functions", "broadcast", x@sdf) + dataFrame(sdf) + }) diff --git a/R/pkg/R/WindowSpec.R b/R/pkg/R/WindowSpec.R index 4ac83c29c6f7..81beac9ea992 100644 --- a/R/pkg/R/WindowSpec.R +++ b/R/pkg/R/WindowSpec.R @@ -203,7 +203,8 @@ setMethod("rangeBetween", #' @aliases over,Column,WindowSpec-method #' @family colum_func #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' df <- createDataFrame(mtcars) #' #' # Partition by am (transmission) and order by hp (horsepower) diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 147ee4b6887b..a5c2ea81f249 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -130,19 +130,20 @@ createMethods <- function() { createMethods() -#' alias -#' -#' Set a new name for a column -#' -#' @param object Column to rename -#' @param data new name to use -#' #' @rdname alias #' @name alias #' @aliases alias,Column-method #' @family colum_func #' @export -#' @note alias since 1.4.0 +#' @examples +#' \dontrun{ +#' df <- createDataFrame(iris) +#' +#' head(select( +#' df, alias(df$Sepal_Length, "slength"), alias(df$Petal_Length, "plength") +#' )) +#' } +#' @note alias(Column) since 1.4.0 setMethod("alias", signature(object = "Column"), function(object, data) { @@ -244,7 +245,8 @@ setMethod("between", signature(x = "Column"), #' @family colum_func #' @aliases cast,Column-method #' -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' cast(df$age, "string") #' } #' @note cast since 1.4.0 diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 50856e3d9856..8349b57a30a9 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -258,7 +258,7 @@ includePackage <- function(sc, pkg) { #' #' # Large Matrix object that we want to broadcast #' randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000)) -#' randomMatBr <- broadcast(sc, randomMat) +#' randomMatBr <- broadcastRDD(sc, randomMat) #' #' # Use the broadcast variable inside the function #' useBroadcast <- function(x) { @@ -266,7 +266,7 @@ includePackage <- function(sc, pkg) { #' } #' sumRDD <- lapply(rdd, useBroadcast) #'} -broadcast <- function(sc, object) { +broadcastRDD <- function(sc, object) { objName <- as.character(substitute(object)) serializedObj <- serialize(object, connection = NULL) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 5f9d11475c94..06a90192bb12 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -24,7 +24,7 @@ NULL #' If the parameter is a \linkS4class{Column}, it is returned unchanged. #' #' @param x a literal value or a Column. 
-#' @family normal_funcs +#' @family non-aggregate functions #' @rdname lit #' @name lit #' @export @@ -52,7 +52,7 @@ setMethod("lit", signature("ANY"), #' #' @rdname abs #' @name abs -#' @family normal_funcs +#' @family non-aggregate functions #' @export #' @examples \dontrun{abs(df$c)} #' @aliases abs,Column-method @@ -73,7 +73,7 @@ setMethod("abs", #' #' @rdname acos #' @name acos -#' @family math_funcs +#' @family math functions #' @export #' @examples \dontrun{acos(df$c)} #' @aliases acos,Column-method @@ -113,7 +113,7 @@ setMethod("approxCountDistinct", #' #' @rdname ascii #' @name ascii -#' @family string_funcs +#' @family string functions #' @export #' @aliases ascii,Column-method #' @examples \dontrun{\dontrun{ascii(df$c)}} @@ -134,7 +134,7 @@ setMethod("ascii", #' #' @rdname asin #' @name asin -#' @family math_funcs +#' @family math functions #' @export #' @aliases asin,Column-method #' @examples \dontrun{asin(df$c)} @@ -154,7 +154,7 @@ setMethod("asin", #' #' @rdname atan #' @name atan -#' @family math_funcs +#' @family math functions #' @export #' @aliases atan,Column-method #' @examples \dontrun{atan(df$c)} @@ -172,7 +172,7 @@ setMethod("atan", #' #' @rdname avg #' @name avg -#' @family agg_funcs +#' @family aggregate functions #' @export #' @aliases avg,Column-method #' @examples \dontrun{avg(df$c)} @@ -193,7 +193,7 @@ setMethod("avg", #' #' @rdname base64 #' @name base64 -#' @family string_funcs +#' @family string functions #' @export #' @aliases base64,Column-method #' @examples \dontrun{base64(df$c)} @@ -214,7 +214,7 @@ setMethod("base64", #' #' @rdname bin #' @name bin -#' @family math_funcs +#' @family math functions #' @export #' @aliases bin,Column-method #' @examples \dontrun{bin(df$c)} @@ -234,7 +234,7 @@ setMethod("bin", #' #' @rdname bitwiseNOT #' @name bitwiseNOT -#' @family normal_funcs +#' @family non-aggregate functions #' @export #' @aliases bitwiseNOT,Column-method #' @examples \dontrun{bitwiseNOT(df$c)} @@ -254,7 +254,7 @@ setMethod("bitwiseNOT", #' #' @rdname cbrt #' @name cbrt -#' @family math_funcs +#' @family math functions #' @export #' @aliases cbrt,Column-method #' @examples \dontrun{cbrt(df$c)} @@ -274,7 +274,7 @@ setMethod("cbrt", #' #' @rdname ceil #' @name ceil -#' @family math_funcs +#' @family math functions #' @export #' @aliases ceil,Column-method #' @examples \dontrun{ceil(df$c)} @@ -292,7 +292,7 @@ setMethod("ceil", #' #' @rdname coalesce #' @name coalesce -#' @family normal_funcs +#' @family non-aggregate functions #' @export #' @aliases coalesce,Column-method #' @examples \dontrun{coalesce(df$c, df$d, df$e)} @@ -324,7 +324,7 @@ col <- function(x) { #' #' @rdname column #' @name column -#' @family normal_funcs +#' @family non-aggregate functions #' @export #' @aliases column,character-method #' @examples \dontrun{column("name")} @@ -342,7 +342,7 @@ setMethod("column", #' #' @rdname corr #' @name corr -#' @family math_funcs +#' @family math functions #' @export #' @aliases corr,Column-method #' @examples \dontrun{corr(df$c, df$d)} @@ -360,7 +360,7 @@ setMethod("corr", signature(x = "Column"), #' #' @rdname cov #' @name cov -#' @family math_funcs +#' @family math functions #' @export #' @aliases cov,characterOrColumn-method #' @examples @@ -404,7 +404,7 @@ setMethod("covar_samp", signature(col1 = "characterOrColumn", col2 = "characterO #' #' @rdname covar_pop #' @name covar_pop -#' @family math_funcs +#' @family math functions #' @export #' @aliases covar_pop,characterOrColumn,characterOrColumn-method #' @examples @@ -432,7 +432,7 @@ 
setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr #' #' @rdname cos #' @name cos -#' @family math_funcs +#' @family math functions #' @aliases cos,Column-method #' @export #' @examples \dontrun{cos(df$c)} @@ -452,7 +452,7 @@ setMethod("cos", #' #' @rdname cosh #' @name cosh -#' @family math_funcs +#' @family math functions #' @aliases cosh,Column-method #' @export #' @examples \dontrun{cosh(df$c)} @@ -471,7 +471,7 @@ setMethod("cosh", #' #' @rdname count #' @name count -#' @family agg_funcs +#' @family aggregate functions #' @aliases count,Column-method #' @export #' @examples \dontrun{count(df$c)} @@ -492,7 +492,7 @@ setMethod("count", #' #' @rdname crc32 #' @name crc32 -#' @family misc_funcs +#' @family misc functions #' @aliases crc32,Column-method #' @export #' @examples \dontrun{crc32(df$c)} @@ -513,7 +513,7 @@ setMethod("crc32", #' #' @rdname hash #' @name hash -#' @family misc_funcs +#' @family misc functions #' @aliases hash,Column-method #' @export #' @examples \dontrun{hash(df$c)} @@ -537,7 +537,7 @@ setMethod("hash", #' #' @rdname dayofmonth #' @name dayofmonth -#' @family datetime_funcs +#' @family date time functions #' @aliases dayofmonth,Column-method #' @export #' @examples \dontrun{dayofmonth(df$c)} @@ -557,7 +557,7 @@ setMethod("dayofmonth", #' #' @rdname dayofyear #' @name dayofyear -#' @family datetime_funcs +#' @family date time functions #' @aliases dayofyear,Column-method #' @export #' @examples \dontrun{dayofyear(df$c)} @@ -579,7 +579,7 @@ setMethod("dayofyear", #' #' @rdname decode #' @name decode -#' @family string_funcs +#' @family string functions #' @aliases decode,Column,character-method #' @export #' @examples \dontrun{decode(df$c, "UTF-8")} @@ -601,7 +601,7 @@ setMethod("decode", #' #' @rdname encode #' @name encode -#' @family string_funcs +#' @family string functions #' @aliases encode,Column,character-method #' @export #' @examples \dontrun{encode(df$c, "UTF-8")} @@ -621,7 +621,7 @@ setMethod("encode", #' #' @rdname exp #' @name exp -#' @family math_funcs +#' @family math functions #' @aliases exp,Column-method #' @export #' @examples \dontrun{exp(df$c)} @@ -642,7 +642,7 @@ setMethod("exp", #' @rdname expm1 #' @name expm1 #' @aliases expm1,Column-method -#' @family math_funcs +#' @family math functions #' @export #' @examples \dontrun{expm1(df$c)} #' @note expm1 since 1.5.0 @@ -662,7 +662,7 @@ setMethod("expm1", #' @rdname factorial #' @name factorial #' @aliases factorial,Column-method -#' @family math_funcs +#' @family math functions #' @export #' @examples \dontrun{factorial(df$c)} #' @note factorial since 1.5.0 @@ -686,7 +686,7 @@ setMethod("factorial", #' @rdname first #' @name first #' @aliases first,characterOrColumn-method -#' @family agg_funcs +#' @family aggregate functions #' @export #' @examples #' \dontrun{ @@ -715,7 +715,7 @@ setMethod("first", #' @rdname floor #' @name floor #' @aliases floor,Column-method -#' @family math_funcs +#' @family math functions #' @export #' @examples \dontrun{floor(df$c)} #' @note floor since 1.5.0 @@ -734,7 +734,7 @@ setMethod("floor", #' #' @rdname hex #' @name hex -#' @family math_funcs +#' @family math functions #' @aliases hex,Column-method #' @export #' @examples \dontrun{hex(df$c)} @@ -755,7 +755,7 @@ setMethod("hex", #' @rdname hour #' @name hour #' @aliases hour,Column-method -#' @family datetime_funcs +#' @family date time functions #' @export #' @examples \dontrun{hour(df$c)} #' @note hour since 1.5.0 @@ -777,7 +777,7 @@ setMethod("hour", #' #' @rdname initcap #' @name 
initcap -#' @family string_funcs +#' @family string functions #' @aliases initcap,Column-method #' @export #' @examples \dontrun{initcap(df$c)} @@ -797,7 +797,7 @@ setMethod("initcap", #' #' @rdname is.nan #' @name is.nan -#' @family normal_funcs +#' @family non-aggregate functions #' @aliases is.nan,Column-method #' @export #' @examples @@ -832,7 +832,7 @@ setMethod("isnan", #' @rdname kurtosis #' @name kurtosis #' @aliases kurtosis,Column-method -#' @family agg_funcs +#' @family aggregate functions #' @export #' @examples \dontrun{kurtosis(df$c)} #' @note kurtosis since 1.6.0 @@ -858,7 +858,7 @@ setMethod("kurtosis", #' @rdname last #' @name last #' @aliases last,characterOrColumn-method -#' @family agg_funcs +#' @family aggregate functions #' @export #' @examples #' \dontrun{ @@ -889,7 +889,7 @@ setMethod("last", #' @rdname last_day #' @name last_day #' @aliases last_day,Column-method -#' @family datetime_funcs +#' @family date time functions #' @export #' @examples \dontrun{last_day(df$c)} #' @note last_day since 1.5.0 @@ -909,7 +909,7 @@ setMethod("last_day", #' @rdname length #' @name length #' @aliases length,Column-method -#' @family string_funcs +#' @family string functions #' @export #' @examples \dontrun{length(df$c)} #' @note length since 1.5.0 @@ -929,7 +929,7 @@ setMethod("length", #' @rdname log #' @name log #' @aliases log,Column-method -#' @family math_funcs +#' @family math functions #' @export #' @examples \dontrun{log(df$c)} #' @note log since 1.5.0 @@ -948,7 +948,7 @@ setMethod("log", #' #' @rdname log10 #' @name log10 -#' @family math_funcs +#' @family math functions #' @aliases log10,Column-method #' @export #' @examples \dontrun{log10(df$c)} @@ -968,7 +968,7 @@ setMethod("log10", #' #' @rdname log1p #' @name log1p -#' @family math_funcs +#' @family math functions #' @aliases log1p,Column-method #' @export #' @examples \dontrun{log1p(df$c)} @@ -988,7 +988,7 @@ setMethod("log1p", #' #' @rdname log2 #' @name log2 -#' @family math_funcs +#' @family math functions #' @aliases log2,Column-method #' @export #' @examples \dontrun{log2(df$c)} @@ -1008,7 +1008,7 @@ setMethod("log2", #' #' @rdname lower #' @name lower -#' @family string_funcs +#' @family string functions #' @aliases lower,Column-method #' @export #' @examples \dontrun{lower(df$c)} @@ -1028,7 +1028,7 @@ setMethod("lower", #' #' @rdname ltrim #' @name ltrim -#' @family string_funcs +#' @family string functions #' @aliases ltrim,Column-method #' @export #' @examples \dontrun{ltrim(df$c)} @@ -1048,7 +1048,7 @@ setMethod("ltrim", #' #' @rdname max #' @name max -#' @family agg_funcs +#' @family aggregate functions #' @aliases max,Column-method #' @export #' @examples \dontrun{max(df$c)} @@ -1069,7 +1069,7 @@ setMethod("max", #' #' @rdname md5 #' @name md5 -#' @family misc_funcs +#' @family misc functions #' @aliases md5,Column-method #' @export #' @examples \dontrun{md5(df$c)} @@ -1090,7 +1090,7 @@ setMethod("md5", #' #' @rdname mean #' @name mean -#' @family agg_funcs +#' @family aggregate functions #' @aliases mean,Column-method #' @export #' @examples \dontrun{mean(df$c)} @@ -1111,7 +1111,7 @@ setMethod("mean", #' @rdname min #' @name min #' @aliases min,Column-method -#' @family agg_funcs +#' @family aggregate functions #' @export #' @examples \dontrun{min(df$c)} #' @note min since 1.5.0 @@ -1131,7 +1131,7 @@ setMethod("min", #' @rdname minute #' @name minute #' @aliases minute,Column-method -#' @family datetime_funcs +#' @family date time functions #' @export #' @examples \dontrun{minute(df$c)} #' @note minute 
since 1.5.0 @@ -1160,7 +1160,7 @@ setMethod("minute", #' @rdname monotonically_increasing_id #' @aliases monotonically_increasing_id,missing-method #' @name monotonically_increasing_id -#' @family misc_funcs +#' @family misc functions #' @export #' @examples \dontrun{select(df, monotonically_increasing_id())} setMethod("monotonically_increasing_id", @@ -1179,7 +1179,7 @@ setMethod("monotonically_increasing_id", #' @rdname month #' @name month #' @aliases month,Column-method -#' @family datetime_funcs +#' @family date time functions #' @export #' @examples \dontrun{month(df$c)} #' @note month since 1.5.0 @@ -1198,7 +1198,7 @@ setMethod("month", #' #' @rdname negate #' @name negate -#' @family normal_funcs +#' @family non-aggregate functions #' @aliases negate,Column-method #' @export #' @examples \dontrun{negate(df$c)} @@ -1218,7 +1218,7 @@ setMethod("negate", #' #' @rdname quarter #' @name quarter -#' @family datetime_funcs +#' @family date time functions #' @aliases quarter,Column-method #' @export #' @examples \dontrun{quarter(df$c)} @@ -1238,7 +1238,7 @@ setMethod("quarter", #' #' @rdname reverse #' @name reverse -#' @family string_funcs +#' @family string functions #' @aliases reverse,Column-method #' @export #' @examples \dontrun{reverse(df$c)} @@ -1259,7 +1259,7 @@ setMethod("reverse", #' #' @rdname rint #' @name rint -#' @family math_funcs +#' @family math functions #' @aliases rint,Column-method #' @export #' @examples \dontrun{rint(df$c)} @@ -1279,7 +1279,7 @@ setMethod("rint", #' #' @rdname round #' @name round -#' @family math_funcs +#' @family math functions #' @aliases round,Column-method #' @export #' @examples \dontrun{round(df$c)} @@ -1305,7 +1305,7 @@ setMethod("round", #' @param ... further arguments to be passed to or from other methods. #' @rdname bround #' @name bround -#' @family math_funcs +#' @family math functions #' @aliases bround,Column-method #' @export #' @examples \dontrun{bround(df$c, 0)} @@ -1326,7 +1326,7 @@ setMethod("bround", #' #' @rdname rtrim #' @name rtrim -#' @family string_funcs +#' @family string functions #' @aliases rtrim,Column-method #' @export #' @examples \dontrun{rtrim(df$c)} @@ -1346,7 +1346,7 @@ setMethod("rtrim", #' @param na.rm currently not used. 
#' @rdname sd #' @name sd -#' @family agg_funcs +#' @family aggregate functions #' @aliases sd,Column-method #' @seealso \link{stddev_pop}, \link{stddev_samp} #' @export @@ -1372,7 +1372,7 @@ setMethod("sd", #' #' @rdname second #' @name second -#' @family datetime_funcs +#' @family date time functions #' @aliases second,Column-method #' @export #' @examples \dontrun{second(df$c)} @@ -1393,7 +1393,7 @@ setMethod("second", #' #' @rdname sha1 #' @name sha1 -#' @family misc_funcs +#' @family misc functions #' @aliases sha1,Column-method #' @export #' @examples \dontrun{sha1(df$c)} @@ -1414,7 +1414,7 @@ setMethod("sha1", #' @rdname sign #' @name signum #' @aliases signum,Column-method -#' @family math_funcs +#' @family math functions #' @export #' @examples \dontrun{signum(df$c)} #' @note signum since 1.5.0 @@ -1433,7 +1433,7 @@ setMethod("signum", #' #' @rdname sin #' @name sin -#' @family math_funcs +#' @family math functions #' @aliases sin,Column-method #' @export #' @examples \dontrun{sin(df$c)} @@ -1453,7 +1453,7 @@ setMethod("sin", #' #' @rdname sinh #' @name sinh -#' @family math_funcs +#' @family math functions #' @aliases sinh,Column-method #' @export #' @examples \dontrun{sinh(df$c)} @@ -1473,7 +1473,7 @@ setMethod("sinh", #' #' @rdname skewness #' @name skewness -#' @family agg_funcs +#' @family aggregate functions #' @aliases skewness,Column-method #' @export #' @examples \dontrun{skewness(df$c)} @@ -1493,7 +1493,7 @@ setMethod("skewness", #' #' @rdname soundex #' @name soundex -#' @family string_funcs +#' @family string functions #' @aliases soundex,Column-method #' @export #' @examples \dontrun{soundex(df$c)} @@ -1546,7 +1546,7 @@ setMethod("stddev", #' #' @rdname stddev_pop #' @name stddev_pop -#' @family agg_funcs +#' @family aggregate functions #' @aliases stddev_pop,Column-method #' @seealso \link{sd}, \link{stddev_samp} #' @export @@ -1567,7 +1567,7 @@ setMethod("stddev_pop", #' #' @rdname stddev_samp #' @name stddev_samp -#' @family agg_funcs +#' @family aggregate functions #' @aliases stddev_samp,Column-method #' @seealso \link{stddev_pop}, \link{sd} #' @export @@ -1589,7 +1589,7 @@ setMethod("stddev_samp", #' #' @rdname struct #' @name struct -#' @family normal_funcs +#' @family non-aggregate functions #' @aliases struct,characterOrColumn-method #' @export #' @examples @@ -1618,7 +1618,7 @@ setMethod("struct", #' #' @rdname sqrt #' @name sqrt -#' @family math_funcs +#' @family math functions #' @aliases sqrt,Column-method #' @export #' @examples \dontrun{sqrt(df$c)} @@ -1638,7 +1638,7 @@ setMethod("sqrt", #' #' @rdname sum #' @name sum -#' @family agg_funcs +#' @family aggregate functions #' @aliases sum,Column-method #' @export #' @examples \dontrun{sum(df$c)} @@ -1658,7 +1658,7 @@ setMethod("sum", #' #' @rdname sumDistinct #' @name sumDistinct -#' @family agg_funcs +#' @family aggregate functions #' @aliases sumDistinct,Column-method #' @export #' @examples \dontrun{sumDistinct(df$c)} @@ -1678,7 +1678,7 @@ setMethod("sumDistinct", #' #' @rdname tan #' @name tan -#' @family math_funcs +#' @family math functions #' @aliases tan,Column-method #' @export #' @examples \dontrun{tan(df$c)} @@ -1698,7 +1698,7 @@ setMethod("tan", #' #' @rdname tanh #' @name tanh -#' @family math_funcs +#' @family math functions #' @aliases tanh,Column-method #' @export #' @examples \dontrun{tanh(df$c)} @@ -1718,7 +1718,7 @@ setMethod("tanh", #' #' @rdname toDegrees #' @name toDegrees -#' @family math_funcs +#' @family math functions #' @aliases toDegrees,Column-method #' @export #' @examples 
\dontrun{toDegrees(df$c)} @@ -1738,7 +1738,7 @@ setMethod("toDegrees", #' #' @rdname toRadians #' @name toRadians -#' @family math_funcs +#' @family math functions #' @aliases toRadians,Column-method #' @export #' @examples \dontrun{toRadians(df$c)} @@ -1757,14 +1757,15 @@ setMethod("toRadians", #' \url{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}. #' If the string cannot be parsed according to the specified format (or default), #' the value of the column will be null. -#' The default format is 'yyyy-MM-dd'. +#' By default, it follows casting rules to a DateType if the format is omitted +#' (equivalent to \code{cast(df$x, "date")}). #' #' @param x Column to parse. #' @param format string to use to parse x Column to DateType. (optional) #' #' @rdname to_date #' @name to_date -#' @family datetime_funcs +#' @family date time functions #' @aliases to_date,Column,missing-method #' @export #' @examples @@ -1782,7 +1783,7 @@ setMethod("to_date", #' @rdname to_date #' @name to_date -#' @family datetime_funcs +#' @family date time functions #' @aliases to_date,Column,character-method #' @export #' @note to_date(Column, character) since 2.2.0 @@ -1802,7 +1803,7 @@ setMethod("to_date", #' @param ... additional named properties to control how it is converted, accepts the same options #' as the JSON data source. #' -#' @family normal_funcs +#' @family non-aggregate functions #' @rdname to_json #' @name to_json #' @aliases to_json,Column-method @@ -1832,14 +1833,15 @@ setMethod("to_json", signature(x = "Column"), #' \url{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}. #' If the string cannot be parsed according to the specified format (or default), #' the value of the column will be null. -#' The default format is 'yyyy-MM-dd HH:mm:ss'. +#' By default, it follows casting rules to a TimestampType if the format is omitted +#' (equivalent to \code{cast(df$x, "timestamp")}). #' #' @param x Column to parse. -#' @param format string to use to parse x Column to DateType. (optional) +#' @param format string to use to parse x Column to TimestampType. (optional) #' #' @rdname to_timestamp #' @name to_timestamp -#' @family datetime_funcs +#' @family date time functions #' @aliases to_timestamp,Column,missing-method #' @export #' @examples @@ -1857,7 +1859,7 @@ setMethod("to_timestamp", #' @rdname to_timestamp #' @name to_timestamp -#' @family datetime_funcs +#' @family date time functions #' @aliases to_timestamp,Column,character-method #' @export #' @note to_timestamp(Column, character) since 2.2.0 @@ -1876,7 +1878,7 @@ setMethod("to_timestamp", #' #' @rdname trim #' @name trim -#' @family string_funcs +#' @family string functions #' @aliases trim,Column-method #' @export #' @examples \dontrun{trim(df$c)} @@ -1897,7 +1899,7 @@ setMethod("trim", #' #' @rdname unbase64 #' @name unbase64 -#' @family string_funcs +#' @family string functions #' @aliases unbase64,Column-method #' @export #' @examples \dontrun{unbase64(df$c)} @@ -1918,7 +1920,7 @@ setMethod("unbase64", #' #' @rdname unhex #' @name unhex -#' @family math_funcs +#' @family math functions #' @aliases unhex,Column-method #' @export #' @examples \dontrun{unhex(df$c)} @@ -1938,7 +1940,7 @@ setMethod("unhex", #' #' @rdname upper #' @name upper -#' @family string_funcs +#' @family string functions #' @aliases upper,Column-method #' @export #' @examples \dontrun{upper(df$c)} @@ -1958,7 +1960,7 @@ setMethod("upper", #' @param y,na.rm,use currently not used. 
#' @rdname var #' @name var -#' @family agg_funcs +#' @family aggregate functions #' @aliases var,Column-method #' @seealso \link{var_pop}, \link{var_samp} #' @export @@ -1995,7 +1997,7 @@ setMethod("variance", #' #' @rdname var_pop #' @name var_pop -#' @family agg_funcs +#' @family aggregate functions #' @aliases var_pop,Column-method #' @seealso \link{var}, \link{var_samp} #' @export @@ -2017,7 +2019,7 @@ setMethod("var_pop", #' @rdname var_samp #' @name var_samp #' @aliases var_samp,Column-method -#' @family agg_funcs +#' @family aggregate functions #' @seealso \link{var_pop}, \link{var} #' @export #' @examples \dontrun{var_samp(df$c)} @@ -2038,7 +2040,7 @@ setMethod("var_samp", #' @rdname weekofyear #' @name weekofyear #' @aliases weekofyear,Column-method -#' @family datetime_funcs +#' @family date time functions #' @export #' @examples \dontrun{weekofyear(df$c)} #' @note weekofyear since 1.5.0 @@ -2057,7 +2059,7 @@ setMethod("weekofyear", #' #' @rdname year #' @name year -#' @family datetime_funcs +#' @family date time functions #' @aliases year,Column-method #' @export #' @examples \dontrun{year(df$c)} @@ -2079,7 +2081,7 @@ setMethod("year", #' #' @rdname atan2 #' @name atan2 -#' @family math_funcs +#' @family math functions #' @aliases atan2,Column-method #' @export #' @examples \dontrun{atan2(df$c, x)} @@ -2103,7 +2105,7 @@ setMethod("atan2", signature(y = "Column"), #' @rdname datediff #' @name datediff #' @aliases datediff,Column-method -#' @family datetime_funcs +#' @family date time functions #' @export #' @examples \dontrun{datediff(df$c, x)} #' @note datediff since 1.5.0 @@ -2125,7 +2127,7 @@ setMethod("datediff", signature(y = "Column"), #' #' @rdname hypot #' @name hypot -#' @family math_funcs +#' @family math functions #' @aliases hypot,Column-method #' @export #' @examples \dontrun{hypot(df$c, x)} @@ -2148,7 +2150,7 @@ setMethod("hypot", signature(y = "Column"), #' #' @rdname levenshtein #' @name levenshtein -#' @family string_funcs +#' @family string functions #' @aliases levenshtein,Column-method #' @export #' @examples \dontrun{levenshtein(df$c, x)} @@ -2171,7 +2173,7 @@ setMethod("levenshtein", signature(y = "Column"), #' #' @rdname months_between #' @name months_between -#' @family datetime_funcs +#' @family date time functions #' @aliases months_between,Column-method #' @export #' @examples \dontrun{months_between(df$c, x)} @@ -2195,7 +2197,7 @@ setMethod("months_between", signature(y = "Column"), #' #' @rdname nanvl #' @name nanvl -#' @family normal_funcs +#' @family non-aggregate functions #' @aliases nanvl,Column-method #' @export #' @examples \dontrun{nanvl(df$c, x)} @@ -2219,7 +2221,7 @@ setMethod("nanvl", signature(y = "Column"), #' @rdname pmod #' @name pmod #' @docType methods -#' @family math_funcs +#' @family math functions #' @aliases pmod,Column-method #' @export #' @examples \dontrun{pmod(df$c, x)} @@ -2257,7 +2259,7 @@ setMethod("approxCountDistinct", #' @param x Column to compute on #' @param ... other columns #' -#' @family agg_funcs +#' @family aggregate functions #' @rdname countDistinct #' @name countDistinct #' @aliases countDistinct,Column-method @@ -2285,7 +2287,7 @@ setMethod("countDistinct", #' @param x Column to compute on #' @param ... other columns #' -#' @family string_funcs +#' @family string functions #' @rdname concat #' @name concat #' @aliases concat,Column-method @@ -2311,7 +2313,7 @@ setMethod("concat", #' @param x Column to compute on #' @param ... 
other columns #' -#' @family normal_funcs +#' @family non-aggregate functions #' @rdname greatest #' @name greatest #' @aliases greatest,Column-method @@ -2338,7 +2340,7 @@ setMethod("greatest", #' @param x Column to compute on #' @param ... other columns #' -#' @family normal_funcs +#' @family non-aggregate functions #' @rdname least #' @aliases least,Column-method #' @name least @@ -2422,7 +2424,7 @@ setMethod("n", signature(x = "Column"), #' @param y Column to compute on. #' @param x date format specification. #' -#' @family datetime_funcs +#' @family date time functions #' @rdname date_format #' @name date_format #' @aliases date_format,Column,character-method @@ -2447,7 +2449,7 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' @param ... additional named properties to control how the json is parsed, accepts the same #' options as the JSON data source. #' -#' @family normal_funcs +#' @family non-aggregate functions #' @rdname from_json #' @name from_json #' @aliases from_json,Column,structType-method @@ -2482,7 +2484,7 @@ setMethod("from_json", signature(x = "Column", schema = "structType"), #' @param y Column to compute on. #' @param x time zone to use. #' -#' @family datetime_funcs +#' @family date time functions #' @rdname from_utc_timestamp #' @name from_utc_timestamp #' @aliases from_utc_timestamp,Column,character-method @@ -2505,7 +2507,7 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), #' #' @param y column to check #' @param x substring to check -#' @family string_funcs +#' @family string functions #' @aliases instr,Column,character-method #' @rdname instr #' @name instr @@ -2532,7 +2534,7 @@ setMethod("instr", signature(y = "Column", x = "character"), #' @param y Column to compute on. #' @param x Day of the week string. 
#' -#' @family datetime_funcs +#' @family date time functions #' @rdname next_day #' @name next_day #' @aliases next_day,Column,character-method @@ -2557,7 +2559,7 @@ setMethod("next_day", signature(y = "Column", x = "character"), #' @param y Column to compute on #' @param x timezone to use #' -#' @family datetime_funcs +#' @family date time functions #' @rdname to_utc_timestamp #' @name to_utc_timestamp #' @aliases to_utc_timestamp,Column,character-method @@ -2578,7 +2580,7 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), #' @param x Number of months to add #' #' @name add_months -#' @family datetime_funcs +#' @family date time functions #' @rdname add_months #' @aliases add_months,Column,numeric-method #' @export @@ -2597,7 +2599,7 @@ setMethod("add_months", signature(y = "Column", x = "numeric"), #' @param y Column to compute on #' @param x Number of days to add #' -#' @family datetime_funcs +#' @family date time functions #' @rdname date_add #' @name date_add #' @aliases date_add,Column,numeric-method @@ -2617,7 +2619,7 @@ setMethod("date_add", signature(y = "Column", x = "numeric"), #' @param y Column to compute on #' @param x Number of days to substract #' -#' @family datetime_funcs +#' @family date time functions #' @rdname date_sub #' @name date_sub #' @aliases date_sub,Column,numeric-method @@ -2640,7 +2642,7 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), #' #' @param y column to format #' @param x number of decimal place to format to -#' @family string_funcs +#' @family string functions #' @rdname format_number #' @name format_number #' @aliases format_number,Column,numeric-method @@ -2662,7 +2664,7 @@ setMethod("format_number", signature(y = "Column", x = "numeric"), #' #' @param y column to compute SHA-2 on. #' @param x one of 224, 256, 384, or 512. -#' @family misc_funcs +#' @family misc functions #' @rdname sha2 #' @name sha2 #' @aliases sha2,Column,numeric-method @@ -2683,7 +2685,7 @@ setMethod("sha2", signature(y = "Column", x = "numeric"), #' @param y column to compute on. #' @param x number of bits to shift. #' -#' @family math_funcs +#' @family math functions #' @rdname shiftLeft #' @name shiftLeft #' @aliases shiftLeft,Column,numeric-method @@ -2706,7 +2708,7 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), #' @param y column to compute on. #' @param x number of bits to shift. #' -#' @family math_funcs +#' @family math functions #' @rdname shiftRight #' @name shiftRight #' @aliases shiftRight,Column,numeric-method @@ -2729,7 +2731,7 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), #' @param y column to compute on. #' @param x number of bits to shift. #' -#' @family math_funcs +#' @family math functions #' @rdname shiftRightUnsigned #' @name shiftRightUnsigned #' @aliases shiftRightUnsigned,Column,numeric-method @@ -2753,7 +2755,7 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @param sep separator to use. #' @param ... other columns to concatenate. #' -#' @family string_funcs +#' @family string functions #' @rdname concat_ws #' @name concat_ws #' @aliases concat_ws,character,Column-method @@ -2775,7 +2777,7 @@ setMethod("concat_ws", signature(sep = "character", x = "Column"), #' @param fromBase base to convert from. #' @param toBase base to convert to. 
#' -#' @family math_funcs +#' @family math functions #' @rdname conv #' @aliases conv,Column,numeric,numeric-method #' @name conv @@ -2798,7 +2800,7 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri #' SparkDataFrame.selectExpr #' #' @param x an expression character object to be parsed. -#' @family normal_funcs +#' @family non-aggregate functions #' @rdname expr #' @aliases expr,character-method #' @name expr @@ -2818,7 +2820,7 @@ setMethod("expr", signature(x = "character"), #' @param format a character object of format strings. #' @param x a Column. #' @param ... additional Column(s). -#' @family string_funcs +#' @family string functions #' @rdname format_string #' @name format_string #' @aliases format_string,character,Column-method @@ -2845,7 +2847,7 @@ setMethod("format_string", signature(format = "character", x = "Column"), #' \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ #' Customizing Formats} for available options. #' @param ... further arguments to be passed to or from other methods. -#' @family datetime_funcs +#' @family date time functions #' @rdname from_unixtime #' @name from_unixtime #' @aliases from_unixtime,Column-method @@ -2890,7 +2892,7 @@ setMethod("from_unixtime", signature(x = "Column"), #' @param ... further arguments to be passed to or from other methods. #' @return An output column of struct called 'window' by default with the nested columns 'start' #' and 'end'. -#' @family datetime_funcs +#' @family date time functions #' @rdname window #' @name window #' @aliases window,Column-method @@ -2946,7 +2948,7 @@ setMethod("window", signature(x = "Column"), #' @param str a Column where matches are sought for each entry. #' @param pos start position of search. #' @param ... further arguments to be passed to or from other methods. -#' @family string_funcs +#' @family string functions #' @rdname locate #' @aliases locate,character,Column-method #' @name locate @@ -2968,7 +2970,7 @@ setMethod("locate", signature(substr = "character", str = "Column"), #' @param x the string Column to be left-padded. #' @param len maximum length of each output result. #' @param pad a character string to be padded with. -#' @family string_funcs +#' @family string functions #' @rdname lpad #' @aliases lpad,Column,numeric,character-method #' @name lpad @@ -2989,7 +2991,7 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' from U[0.0, 1.0]. #' #' @param seed a random seed. Can be missing. -#' @family normal_funcs +#' @family non-aggregate functions #' @rdname rand #' @name rand #' @aliases rand,missing-method @@ -3019,7 +3021,7 @@ setMethod("rand", signature(seed = "numeric"), #' the standard normal distribution. #' #' @param seed a random seed. Can be missing. -#' @family normal_funcs +#' @family non-aggregate functions #' @rdname randn #' @name randn #' @aliases randn,missing-method @@ -3051,7 +3053,7 @@ setMethod("randn", signature(seed = "numeric"), #' @param x a string Column. #' @param pattern a regular expression. #' @param idx a group index. -#' @family string_funcs +#' @family string functions #' @rdname regexp_extract #' @name regexp_extract #' @aliases regexp_extract,Column,character,numeric-method @@ -3074,7 +3076,7 @@ setMethod("regexp_extract", #' @param x a string Column. #' @param pattern a regular expression. #' @param replacement a character string that a matched \code{pattern} is replaced with. 
-#' @family string_funcs +#' @family string functions #' @rdname regexp_replace #' @name regexp_replace #' @aliases regexp_replace,Column,character,character-method @@ -3097,7 +3099,7 @@ setMethod("regexp_replace", #' @param x the string Column to be right-padded. #' @param len maximum length of each output result. #' @param pad a character string to be padded with. -#' @family string_funcs +#' @family string functions #' @rdname rpad #' @name rpad #' @aliases rpad,Column,numeric,character-method @@ -3124,7 +3126,7 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), #' @param count number of occurrences of \code{delim} before the substring is returned. #' A positive number means counting from the left, while negative means #' counting from the right. -#' @family string_funcs +#' @family string functions #' @rdname substring_index #' @aliases substring_index,Column,character,numeric-method #' @name substring_index @@ -3156,7 +3158,7 @@ setMethod("substring_index", #' @param replaceString a target string where each \code{matchingString} character will #' be replaced by the character in \code{replaceString} #' at the same location, if any. -#' @family string_funcs +#' @family string functions #' @rdname translate #' @name translate #' @aliases translate,Column,character,character-method @@ -3175,7 +3177,7 @@ setMethod("translate", #' #' Gets current Unix timestamp in seconds. #' -#' @family datetime_funcs +#' @family date time functions #' @rdname unix_timestamp #' @name unix_timestamp #' @aliases unix_timestamp,missing,missing-method @@ -3225,7 +3227,7 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), #' #' @param condition the condition to test on. Must be a Column expression. #' @param value result expression. -#' @family normal_funcs +#' @family non-aggregate functions #' @rdname when #' @name when #' @aliases when,Column-method @@ -3249,13 +3251,14 @@ setMethod("when", signature(condition = "Column", value = "ANY"), #' @param test a Column expression that describes the condition. #' @param yes return values for \code{TRUE} elements of test. #' @param no return values for \code{FALSE} elements of test. 
-#' @family normal_funcs +#' @family non-aggregate functions #' @rdname ifelse #' @name ifelse #' @aliases ifelse,Column-method #' @seealso \link{when} #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' ifelse(df$a > 1 & df$b > 2, 0, 1) #' ifelse(df$a > 1, df$a, 1) #' } @@ -3287,10 +3290,11 @@ setMethod("ifelse", #' #' @rdname cume_dist #' @name cume_dist -#' @family window_funcs +#' @family window functions #' @aliases cume_dist,missing-method #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' df <- createDataFrame(mtcars) #' ws <- orderBy(windowPartitionBy("am"), "hp") #' out <- select(df, over(cume_dist(), ws), df$hp, df$am) @@ -3316,10 +3320,11 @@ setMethod("cume_dist", #' #' @rdname dense_rank #' @name dense_rank -#' @family window_funcs +#' @family window functions #' @aliases dense_rank,missing-method #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' df <- createDataFrame(mtcars) #' ws <- orderBy(windowPartitionBy("am"), "hp") #' out <- select(df, over(dense_rank(), ws), df$hp, df$am) @@ -3348,9 +3353,10 @@ setMethod("dense_rank", #' @rdname lag #' @name lag #' @aliases lag,characterOrColumn-method -#' @family window_funcs +#' @family window functions #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' df <- createDataFrame(mtcars) #' #' # Partition by am (transmission) and order by hp (horsepower) @@ -3390,10 +3396,11 @@ setMethod("lag", #' #' @rdname lead #' @name lead -#' @family window_funcs +#' @family window functions #' @aliases lead,characterOrColumn,numeric-method #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' df <- createDataFrame(mtcars) #' #' # Partition by am (transmission) and order by hp (horsepower) @@ -3430,9 +3437,10 @@ setMethod("lead", #' @rdname ntile #' @name ntile #' @aliases ntile,numeric-method -#' @family window_funcs +#' @family window functions #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' df <- createDataFrame(mtcars) #' #' # Partition by am (transmission) and order by hp (horsepower) @@ -3461,10 +3469,11 @@ setMethod("ntile", #' #' @rdname percent_rank #' @name percent_rank -#' @family window_funcs +#' @family window functions #' @aliases percent_rank,missing-method #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' df <- createDataFrame(mtcars) #' ws <- orderBy(windowPartitionBy("am"), "hp") #' out <- select(df, over(percent_rank(), ws), df$hp, df$am) @@ -3491,10 +3500,11 @@ setMethod("percent_rank", #' #' @rdname rank #' @name rank -#' @family window_funcs +#' @family window functions #' @aliases rank,missing-method #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' df <- createDataFrame(mtcars) #' ws <- orderBy(windowPartitionBy("am"), "hp") #' out <- select(df, over(rank(), ws), df$hp, df$am) @@ -3529,9 +3539,10 @@ setMethod("rank", #' @rdname row_number #' @name row_number #' @aliases row_number,missing-method -#' @family window_funcs +#' @family window functions #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' df <- createDataFrame(mtcars) #' ws <- orderBy(windowPartitionBy("am"), "hp") #' out <- select(df, over(row_number(), ws), df$hp, df$am) @@ -3555,7 +3566,7 @@ setMethod("row_number", #' @rdname array_contains #' @aliases array_contains,Column-method #' @name array_contains -#' @family collection_funcs +#' @family collection functions #' @export #' @examples \dontrun{array_contains(df$c, 1)} #' @note array_contains since 1.6.0 @@ -3574,7 +3585,7 @@ setMethod("array_contains", #' #' @rdname explode 
#' @name explode -#' @family collection_funcs +#' @family collection functions #' @aliases explode,Column-method #' @export #' @examples \dontrun{explode(df$c)} @@ -3595,7 +3606,7 @@ setMethod("explode", #' @rdname size #' @name size #' @aliases size,Column-method -#' @family collection_funcs +#' @family collection functions #' @export #' @examples \dontrun{size(df$c)} #' @note size since 1.5.0 @@ -3618,7 +3629,7 @@ setMethod("size", #' @rdname sort_array #' @name sort_array #' @aliases sort_array,Column-method -#' @family collection_funcs +#' @family collection functions #' @export #' @examples #' \dontrun{ @@ -3641,7 +3652,7 @@ setMethod("sort_array", #' #' @rdname posexplode #' @name posexplode -#' @family collection_funcs +#' @family collection functions #' @aliases posexplode,Column-method #' @export #' @examples \dontrun{posexplode(df$c)} @@ -3660,7 +3671,7 @@ setMethod("posexplode", #' @param x Column to compute on #' @param ... additional Column(s). #' -#' @family normal_funcs +#' @family non-aggregate functions #' @rdname create_array #' @name create_array #' @aliases create_array,Column-method @@ -3688,7 +3699,7 @@ setMethod("create_array", #' @param x Column to compute on #' @param ... additional Column(s). #' -#' @family normal_funcs +#' @family non-aggregate functions #' @rdname create_map #' @name create_map #' @aliases create_map,Column-method @@ -3714,7 +3725,7 @@ setMethod("create_map", #' #' @rdname collect_list #' @name collect_list -#' @family agg_funcs +#' @family aggregate functions #' @aliases collect_list,Column-method #' @export #' @examples \dontrun{collect_list(df$x)} @@ -3734,7 +3745,7 @@ setMethod("collect_list", #' #' @rdname collect_set #' @name collect_set -#' @family agg_funcs +#' @family aggregate functions #' @aliases collect_set,Column-method #' @export #' @examples \dontrun{collect_set(df$x)} @@ -3756,10 +3767,11 @@ setMethod("collect_set", #' @param pattern Java regular expression #' #' @rdname split_string -#' @family string_funcs +#' @family string functions #' @aliases split_string,Column-method #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' df <- read.text("README.md") #' #' head(select(df, split_string(df$value, "\\s+"))) @@ -3785,10 +3797,11 @@ setMethod("split_string", #' @param n Number of repetitions #' #' @rdname repeat_string -#' @family string_funcs +#' @family string functions #' @aliases repeat_string,Column-method #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' df <- read.text("README.md") #' #' first(select(df, repeat_string(df$value, 3))) @@ -3814,10 +3827,11 @@ setMethod("repeat_string", #' #' @rdname explode_outer #' @name explode_outer -#' @family collection_funcs +#' @family collection functions #' @aliases explode_outer,Column-method #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' df <- createDataFrame(data.frame( #' id = c(1, 2, 3), text = c("a,b,c", NA, "d,e") #' )) @@ -3842,10 +3856,11 @@ setMethod("explode_outer", #' #' @rdname posexplode_outer #' @name posexplode_outer -#' @family collection_funcs +#' @family collection functions #' @aliases posexplode_outer,Column-method #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' df <- createDataFrame(data.frame( #' id = c(1, 2, 3), text = c("a,b,c", NA, "d,e") #' )) @@ -3871,9 +3886,10 @@ setMethod("posexplode_outer", #' @rdname not #' @name not #' @aliases not,Column-method -#' @family normal_funcs +#' @family non-aggregate functions #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' df <- 
createDataFrame(data.frame( #' is_true = c(TRUE, FALSE, NA), #' flag = c(1, 0, 1) @@ -3903,10 +3919,11 @@ setMethod("not", #' #' @rdname grouping_bit #' @name grouping_bit -#' @family agg_funcs +#' @family aggregate functions #' @aliases grouping_bit,Column-method #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' df <- createDataFrame(mtcars) #' #' # With cube @@ -3944,10 +3961,11 @@ setMethod("grouping_bit", #' #' @rdname grouping_id #' @name grouping_id -#' @family agg_funcs +#' @family aggregate functions #' @aliases grouping_id,Column-method #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' df <- createDataFrame(mtcars) #' #' # With cube @@ -3982,10 +4000,11 @@ setMethod("grouping_id", #' #' @rdname input_file_name #' @name input_file_name -#' @family normal_funcs +#' @family non-aggregate functions #' @aliases input_file_name,missing-method #' @export -#' @examples \dontrun{ +#' @examples +#' \dontrun{ #' df <- read.text("README.md") #' #' head(select(df, input_file_name())) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index e835ef3e4f40..5630d0c8a0df 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -387,6 +387,17 @@ setGeneric("value", function(bcast) { standardGeneric("value") }) #' @export setGeneric("agg", function (x, ...) { standardGeneric("agg") }) +#' alias +#' +#' Returns a new SparkDataFrame or a Column with an alias set. Equivalent to SQL "AS" keyword. +#' +#' @name alias +#' @rdname alias +#' @param object x a SparkDataFrame or a Column +#' @param data new name to use +#' @return a SparkDataFrame or a Column +NULL + #' @rdname arrange #' @export setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") }) @@ -788,6 +799,10 @@ setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.d #' @export setGeneric("randomSplit", function(x, weights, seed) { standardGeneric("randomSplit") }) +#' @rdname broadcast +#' @export +setGeneric("broadcast", function(x) { standardGeneric("broadcast") }) + ###################### Column Methods ########################## #' @rdname columnfunctions @@ -1491,6 +1506,11 @@ setGeneric("spark.mlp", function(data, formula, ...) { standardGeneric("spark.ml #' @export setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") }) +#' @rdname spark.decisionTree +#' @export +setGeneric("spark.decisionTree", + function(data, formula, ...) { standardGeneric("spark.decisionTree") }) + #' @rdname spark.randomForest #' @export setGeneric("spark.randomForest", diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index 4db9cc30fb0c..306a9b867653 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -46,15 +46,16 @@ setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj" #' @note NaiveBayesModel since 2.0.0 setClass("NaiveBayesModel", representation(jobj = "jobj")) -#' linear SVM Model +#' Linear SVM Model #' -#' Fits an linear SVM model against a SparkDataFrame. It is a binary classifier, similar to svm in glmnet package +#' Fits a linear SVM model against a SparkDataFrame, similar to svm in e1071 package. +#' Currently only supports binary classification model with linear kernel. #' Users can print, make predictions on the produced model and save the model to the input path. #' #' @param data SparkDataFrame for training. #' @param formula A symbolic description of the model to be fitted. 
Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param regParam The regularization parameter. +#' @param regParam The regularization parameter. Only supports L2 regularization currently. #' @param maxIter Maximum iteration number. #' @param tol Convergence tolerance of iterations. #' @param standardization Whether to standardize the training features before fitting the model. The coefficients @@ -111,10 +112,10 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu new("LinearSVCModel", jobj = jobj) }) -# Predicted values based on an LinearSVCModel model +# Predicted values based on a LinearSVCModel model #' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns the predicted values based on an LinearSVCModel. +#' @return \code{predict} returns the predicted values based on a LinearSVCModel. #' @rdname spark.svmLinear #' @aliases predict,LinearSVCModel,SparkDataFrame-method #' @export @@ -124,13 +125,12 @@ setMethod("predict", signature(object = "LinearSVCModel"), predict_internal(object, newData) }) -# Get the summary of an LinearSVCModel +# Get the summary of a LinearSVCModel -#' @param object an LinearSVCModel fitted by \code{spark.svmLinear}. +#' @param object a LinearSVCModel fitted by \code{spark.svmLinear}. #' @return \code{summary} returns summary information of the fitted model, which is a list. #' The list includes \code{coefficients} (coefficients of the fitted model), -#' \code{intercept} (intercept of the fitted model), \code{numClasses} (number of classes), -#' \code{numFeatures} (number of features). +#' \code{numClasses} (number of classes), \code{numFeatures} (number of features). #' @rdname spark.svmLinear #' @aliases summary,LinearSVCModel-method #' @export @@ -138,22 +138,14 @@ setMethod("predict", signature(object = "LinearSVCModel"), setMethod("summary", signature(object = "LinearSVCModel"), function(object) { jobj <- object@jobj - features <- callJMethod(jobj, "features") - labels <- callJMethod(jobj, "labels") - coefficients <- callJMethod(jobj, "coefficients") - nCol <- length(coefficients) / length(features) - coefficients <- matrix(unlist(coefficients), ncol = nCol) - intercept <- callJMethod(jobj, "intercept") + features <- callJMethod(jobj, "rFeatures") + coefficients <- callJMethod(jobj, "rCoefficients") + coefficients <- as.matrix(unlist(coefficients)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) numClasses <- callJMethod(jobj, "numClasses") numFeatures <- callJMethod(jobj, "numFeatures") - if (nCol == 1) { - colnames(coefficients) <- c("Estimate") - } else { - colnames(coefficients) <- unlist(labels) - } - rownames(coefficients) <- unlist(features) - list(coefficients = coefficients, intercept = intercept, - numClasses = numClasses, numFeatures = numFeatures) + list(coefficients = coefficients, numClasses = numClasses, numFeatures = numFeatures) }) # Save fitted LinearSVCModel to the input path diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 82279be6fbe7..2f1220a75278 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -45,6 +45,20 @@ setClass("RandomForestRegressionModel", representation(jobj = "jobj")) #' @note RandomForestClassificationModel since 2.1.0 setClass("RandomForestClassificationModel", representation(jobj = "jobj")) +#' S4 class that represents a DecisionTreeRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel +#' 
@export +#' @note DecisionTreeRegressionModel since 2.3.0 +setClass("DecisionTreeRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a DecisionTreeClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel +#' @export +#' @note DecisionTreeClassificationModel since 2.3.0 +setClass("DecisionTreeClassificationModel", representation(jobj = "jobj")) + # Create the summary of a tree ensemble model (eg. Random Forest, GBT) summary.treeEnsemble <- function(model) { jobj <- model@jobj @@ -81,6 +95,36 @@ print.summary.treeEnsemble <- function(x) { invisible(x) } +# Create the summary of a decision tree model +summary.decisionTree <- function(model) { + jobj <- model@jobj + formula <- callJMethod(jobj, "formula") + numFeatures <- callJMethod(jobj, "numFeatures") + features <- callJMethod(jobj, "features") + featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString") + maxDepth <- callJMethod(jobj, "maxDepth") + list(formula = formula, + numFeatures = numFeatures, + features = features, + featureImportances = featureImportances, + maxDepth = maxDepth, + jobj = jobj) +} + +# Prints the summary of decision tree models +print.summary.decisionTree <- function(x) { + jobj <- x$jobj + cat("Formula: ", x$formula) + cat("\nNumber of features: ", x$numFeatures) + cat("\nFeatures: ", unlist(x$features)) + cat("\nFeature importances: ", x$featureImportances) + cat("\nMax Depth: ", x$maxDepth) + + summaryStr <- callJMethod(jobj, "summary") + cat("\n", summaryStr, "\n") + invisible(x) +} + #' Gradient Boosted Tree Model for Regression and Classification #' #' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a @@ -499,3 +543,199 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path function(object, path, overwrite = FALSE) { write_internal(object, path, overwrite) }) + +#' Decision Tree Model for Regression and Classification +#' +#' \code{spark.decisionTree} fits a Decision Tree Regression model or Classification model on +#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Decision Tree +#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to +#' save/load fitted models. +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-regression}{ +#' Decision Tree Regression} and +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier}{ +#' Decision Tree Classification} +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' @param type type of model, one of "regression" or "classification", to fit +#' @param maxDepth Maximum depth of the tree (>= 0). +#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing +#' how to split on features at each node. More bins give higher granularity. Must be +#' >= 2 and >= number of categories in any categorical feature. +#' @param impurity Criterion used for information gain calculation. +#' For regression, must be "variance". For classification, must be one of +#' "entropy" and "gini", default is "gini". +#' @param seed integer seed for random number generation. 
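# --- Sketch of the fields exposed by summary.decisionTree() above; `model`
# --- is assumed to be a model fitted with spark.decisionTree() (defined below).
s <- summary(model)
s$formula              # formula string recorded by the wrapper
s$numFeatures
unlist(s$features)
s$featureImportances   # string rendering of the importance vector
s$maxDepth
print(s)               # print.summary.decisionTree() also appends the full tree description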
+#' @param minInstancesPerNode Minimum number of instances each child must have after split. +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. +#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. +#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with +#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching +#' can speed up training of deeper trees. Users can set how often should the +#' cache be checkpointed or disable it by setting checkpointInterval. +#' @param ... additional arguments passed to the method. +#' @aliases spark.decisionTree,SparkDataFrame,formula-method +#' @return \code{spark.decisionTree} returns a fitted Decision Tree model. +#' @rdname spark.decisionTree +#' @name spark.decisionTree +#' @export +#' @examples +#' \dontrun{ +#' # fit a Decision Tree Regression Model +#' df <- createDataFrame(longley) +#' model <- spark.decisionTree(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # fit a Decision Tree Classification Model +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.decisionTree(df, Survived ~ Freq + Age, "classification") +#' } +#' @note spark.decisionTree since 2.3.0 +setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, type = c("regression", "classification"), + maxDepth = 5, maxBins = 32, impurity = NULL, seed = NULL, + minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, + maxMemoryInMB = 256, cacheNodeIds = FALSE) { + type <- match.arg(type) + formula <- paste(deparse(formula), collapse = "") + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } + switch(type, + regression = { + if (is.null(impurity)) impurity <- "variance" + impurity <- match.arg(impurity, "variance") + jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeRegressorWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), impurity, + as.integer(minInstancesPerNode), as.numeric(minInfoGain), + as.integer(checkpointInterval), seed, + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("DecisionTreeRegressionModel", jobj = jobj) + }, + classification = { + if (is.null(impurity)) impurity <- "gini" + impurity <- match.arg(impurity, c("gini", "entropy")) + jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), impurity, + as.integer(minInstancesPerNode), as.numeric(minInfoGain), + as.integer(checkpointInterval), seed, + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("DecisionTreeClassificationModel", jobj = jobj) + } + ) + }) + +# Get the summary of a Decision Tree Regression Model + +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list of components includes \code{formula} (formula), +#' \code{numFeatures} (number of features), \code{features} (list of features), +#' \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of trees). 
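# --- Sketch of the type/impurity dispatch implemented above: "regression"
# --- accepts only "variance" (its default), while "classification" defaults
# --- to "gini" and also accepts "entropy"; an unspecified seed is passed
# --- through as NULL. Data sets follow the roxygen example.
dfReg <- createDataFrame(longley)
regModel <- spark.decisionTree(dfReg, Employed ~ ., type = "regression",
                               maxDepth = 5, maxBins = 16)

dfCls <- createDataFrame(as.data.frame(Titanic))
clsModel <- spark.decisionTree(dfCls, Survived ~ Freq + Age, type = "classification",
                               impurity = "entropy", seed = 42)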
+#' @rdname spark.decisionTree +#' @aliases summary,DecisionTreeRegressionModel-method +#' @export +#' @note summary(DecisionTreeRegressionModel) since 2.3.0 +setMethod("summary", signature(object = "DecisionTreeRegressionModel"), + function(object) { + ans <- summary.decisionTree(object) + class(ans) <- "summary.DecisionTreeRegressionModel" + ans + }) + +# Prints the summary of Decision Tree Regression Model + +#' @param x summary object of Decision Tree regression model or classification model +#' returned by \code{summary}. +#' @rdname spark.decisionTree +#' @export +#' @note print.summary.DecisionTreeRegressionModel since 2.3.0 +print.summary.DecisionTreeRegressionModel <- function(x, ...) { + print.summary.decisionTree(x) +} + +# Get the summary of a Decision Tree Classification Model + +#' @rdname spark.decisionTree +#' @aliases summary,DecisionTreeClassificationModel-method +#' @export +#' @note summary(DecisionTreeClassificationModel) since 2.3.0 +setMethod("summary", signature(object = "DecisionTreeClassificationModel"), + function(object) { + ans <- summary.decisionTree(object) + class(ans) <- "summary.DecisionTreeClassificationModel" + ans + }) + +# Prints the summary of Decision Tree Classification Model + +#' @rdname spark.decisionTree +#' @export +#' @note print.summary.DecisionTreeClassificationModel since 2.3.0 +print.summary.DecisionTreeClassificationModel <- function(x, ...) { + print.summary.decisionTree(x) +} + +# Makes predictions from a Decision Tree Regression model or Classification model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named +#' "prediction". +#' @rdname spark.decisionTree +#' @aliases predict,DecisionTreeRegressionModel-method +#' @export +#' @note predict(DecisionTreeRegressionModel) since 2.3.0 +setMethod("predict", signature(object = "DecisionTreeRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' @rdname spark.decisionTree +#' @aliases predict,DecisionTreeClassificationModel-method +#' @export +#' @note predict(DecisionTreeClassificationModel) since 2.3.0 +setMethod("predict", signature(object = "DecisionTreeClassificationModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Save the Decision Tree Regression or Classification model to the input path. + +#' @param object A fitted Decision Tree regression model or classification model. +#' @param path The directory where the model is saved. +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. 
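# --- Sketch of the predict and write.ml/read.ml methods registered above,
# --- reusing `regModel`/`dfReg` from the previous sketch.
preds <- predict(regModel, dfReg)            # predictions land in a "prediction" column
head(select(preds, "prediction"))

modelPath <- tempfile(pattern = "spark-decisionTreeRegression", fileext = ".tmp")
write.ml(regModel, modelPath, overwrite = TRUE)
restored <- read.ml(modelPath)               # dispatches back to DecisionTreeRegressionModel
summary(restored)$maxDepth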
+#' +#' @aliases write.ml,DecisionTreeRegressionModel,character-method +#' @rdname spark.decisionTree +#' @export +#' @note write.ml(DecisionTreeRegressionModel, character) since 2.3.0 +setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' @aliases write.ml,DecisionTreeClassificationModel,character-method +#' @rdname spark.decisionTree +#' @export +#' @note write.ml(DecisionTreeClassificationModel, character) since 2.3.0 +setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index 5dfef8625061..a53c92c2c481 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -32,8 +32,9 @@ #' @rdname write.ml #' @name write.ml #' @export -#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.gaussianMixture}, -#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg}, +#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree}, +#' @seealso \link{spark.gaussianMixture}, \link{spark.gbt}, +#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg}, #' @seealso \link{spark.kmeans}, #' @seealso \link{spark.lda}, \link{spark.logit}, #' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, @@ -48,8 +49,9 @@ NULL #' @rdname predict #' @name predict #' @export -#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.gaussianMixture}, -#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg}, +#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree}, +#' @seealso \link{spark.gaussianMixture}, \link{spark.gbt}, +#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg}, #' @seealso \link{spark.kmeans}, #' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, #' @seealso \link{spark.randomForest}, \link{spark.survreg}, \link{spark.svmLinear} @@ -110,6 +112,10 @@ read.ml <- function(path) { new("RandomForestRegressionModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) { new("RandomForestClassificationModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeRegressorWrapper")) { + new("DecisionTreeRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeClassifierWrapper")) { + new("DecisionTreeClassificationModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTRegressorWrapper")) { new("GBTRegressionModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTClassifierWrapper")) { diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index d29af00affb9..ea45e394500e 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -907,3 +907,19 @@ basenameSansExtFromUrl <- function(url) { isAtomicLengthOne <- function(x) { is.atomic(x) && length(x) == 1 } + +is_cran <- function() { + !identical(Sys.getenv("NOT_CRAN"), "true") +} + +is_windows <- function() { + .Platform$OS.type == "windows" +} + +hadoop_home_set <- function() { + !identical(Sys.getenv("HADOOP_HOME"), "") +} + +not_cran_or_windows_with_hadoop <- function() { + !is_cran() && (!is_windows() || hadoop_home_set()) +} diff --git a/R/pkg/inst/tests/testthat/jarTest.R b/R/pkg/inst/tests/testthat/jarTest.R index 
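# --- Sketch of how the new internal helpers added to utils.R compose.
# --- is_cran() keys off the NOT_CRAN environment variable (conventionally
# --- set to "true" by devtools/testthat outside CRAN), hadoop_home_set()
# --- off HADOOP_HOME; these are internal, hence the ::: access.
Sys.setenv(NOT_CRAN = "true")                      # emulate a non-CRAN run
SparkR:::is_cran()                                 # FALSE
SparkR:::is_windows()                              # TRUE only on Windows
SparkR:::not_cran_or_windows_with_hadoop()         # TRUE unless Windows without HADOOP_HOME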
c9615c8d4faf..e2241e03b55f 100644 --- a/R/pkg/inst/tests/testthat/jarTest.R +++ b/R/pkg/inst/tests/testthat/jarTest.R @@ -16,7 +16,7 @@ # library(SparkR) -sc <- sparkR.session() +sc <- sparkR.session(master = "local[1]") helloTest <- SparkR:::callJStatic("sparkrtest.DummyClass", "helloWorld", diff --git a/R/pkg/inst/tests/testthat/packageInAJarTest.R b/R/pkg/inst/tests/testthat/packageInAJarTest.R index 4bc935c79eb0..ac706261999f 100644 --- a/R/pkg/inst/tests/testthat/packageInAJarTest.R +++ b/R/pkg/inst/tests/testthat/packageInAJarTest.R @@ -17,7 +17,7 @@ library(SparkR) library(sparkPackageTest) -sparkR.session() +sparkR.session(master = "local[1]") run1 <- myfunc(5L) diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R index 518fb7bd9404..6e160fae1afe 100644 --- a/R/pkg/inst/tests/testthat/test_Serde.R +++ b/R/pkg/inst/tests/testthat/test_Serde.R @@ -17,7 +17,7 @@ context("SerDe functionality") -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("SerDe of primitive types", { skip_on_cran() diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R index 63f54e1af02b..00954fa31b0e 100644 --- a/R/pkg/inst/tests/testthat/test_binaryFile.R +++ b/R/pkg/inst/tests/testthat/test_binaryFile.R @@ -18,7 +18,7 @@ context("functions on binary files") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockFile <- c("Spark is pretty.", "Spark is awesome.") diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R index 25bb2b84266d..236cb3885445 100644 --- a/R/pkg/inst/tests/testthat/test_binary_function.R +++ b/R/pkg/inst/tests/testthat/test_binary_function.R @@ -18,7 +18,7 @@ context("binary functions") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R index 504ded4fc862..2c96740df77b 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/inst/tests/testthat/test_broadcast.R @@ -18,7 +18,7 @@ context("broadcast variables") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Partitioned data @@ -29,7 +29,7 @@ test_that("using broadcast variable", { skip_on_cran() randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) - randomMatBr <- broadcast(sc, randomMat) + randomMatBr <- broadcastRDD(sc, randomMat) useBroadcast <- function(x) { sum(SparkR:::value(randomMatBr) * x) diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index 632a90d68177..f6d9f5423df0 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -60,7 +60,7 @@ test_that("repeatedly starting and stopping SparkR", { skip_on_cran() for (i in 1:4) { - sc <- 
suppressWarnings(sparkR.init()) + sc <- suppressWarnings(sparkR.init(master = sparkRTestMaster)) rdd <- parallelize(sc, 1:20, 2L) expect_equal(countRDD(rdd), 20) suppressWarnings(sparkR.stop()) @@ -69,7 +69,7 @@ test_that("repeatedly starting and stopping SparkR", { test_that("repeatedly starting and stopping SparkSession", { for (i in 1:4) { - sparkR.session(enableHiveSupport = FALSE) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) df <- createDataFrame(data.frame(dummy = 1:i)) expect_equal(count(df), i) sparkR.session.stop() @@ -79,12 +79,12 @@ test_that("repeatedly starting and stopping SparkSession", { test_that("rdd GC across sparkR.stop", { skip_on_cran() - sc <- sparkR.sparkContext() # sc should get id 0 + sc <- sparkR.sparkContext(master = sparkRTestMaster) # sc should get id 0 rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 sparkR.session.stop() - sc <- sparkR.sparkContext() # sc should get id 0 again + sc <- sparkR.sparkContext(master = sparkRTestMaster) # sc should get id 0 again # GC rdd1 before creating rdd3 and rdd2 after rm(rdd1) @@ -104,7 +104,7 @@ test_that("rdd GC across sparkR.stop", { test_that("job group functions can be called", { skip_on_cran() - sc <- sparkR.sparkContext() + sc <- sparkR.sparkContext(master = sparkRTestMaster) setJobGroup("groupId", "job description", TRUE) cancelJobGroup("groupId") clearJobGroup() @@ -118,7 +118,7 @@ test_that("job group functions can be called", { test_that("utility function can be called", { skip_on_cran() - sparkR.sparkContext() + sparkR.sparkContext(master = sparkRTestMaster) setLogLevel("ERROR") sparkR.session.stop() }) @@ -175,7 +175,7 @@ test_that("sparkJars sparkPackages as comma-separated strings", { }) test_that("spark.lapply should perform simple transforms", { - sparkR.sparkContext() + sparkR.sparkContext(master = sparkRTestMaster) doubled <- spark.lapply(1:10, function(x) { 2 * x }) expect_equal(doubled, as.list(2 * 1:10)) sparkR.session.stop() @@ -184,7 +184,7 @@ test_that("spark.lapply should perform simple transforms", { test_that("add and get file to be downloaded with Spark job on every node", { skip_on_cran() - sparkR.sparkContext() + sparkR.sparkContext(master = sparkRTestMaster) # Test add file. 
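# --- Schematic of the test-session pattern applied throughout these files:
# --- the master URL comes from a harness-level variable instead of being
# --- hard-coded. sparkRTestMaster is assumed to be defined by the test setup;
# --- "local[1]" (as used in jarTest.R above) serves here as a stand-in.
# --- Relatedly, the RDD-level broadcast helper in test_broadcast.R is now
# --- called as broadcastRDD, presumably so it no longer collides with the
# --- new DataFrame-level broadcast().
sparkRTestMaster <- "local[1]"                     # hypothetical stand-in for the harness value
sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)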
path <- tempfile(pattern = "hello", fileext = ".txt") filename <- basename(path) diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R index f823ad8e9c98..d7d9eeed1575 100644 --- a/R/pkg/inst/tests/testthat/test_includePackage.R +++ b/R/pkg/inst/tests/testthat/test_includePackage.R @@ -18,7 +18,7 @@ context("include R packages") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Partitioned data diff --git a/R/pkg/inst/tests/testthat/test_jvm_api.R b/R/pkg/inst/tests/testthat/test_jvm_api.R index 7348c893d0af..8b3b4f73de17 100644 --- a/R/pkg/inst/tests/testthat/test_jvm_api.R +++ b/R/pkg/inst/tests/testthat/test_jvm_api.R @@ -17,7 +17,7 @@ context("JVM API") -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("Create and call methods on object", { jarr <- sparkR.newJObject("java.util.ArrayList") diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R index cbc708718286..c1c746828d24 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib classification algorithms, except for tree-based algorithms") # Tests for MLlib classification algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) absoluteSparkPath <- function(x) { sparkHome <- sparkR.conf("spark.home") @@ -38,9 +38,8 @@ test_that("spark.svmLinear", { expect_true(class(summary$coefficients[, 1]) == "numeric") coefs <- summary$coefficients[, "Estimate"] - expected_coefs <- c(-0.1563083, -0.460648, 0.2276626, 1.055085) + expected_coefs <- c(-0.06004978, -0.1563083, -0.460648, 0.2276626, 1.055085) expect_true(all(abs(coefs - expected_coefs) < 0.1)) - expect_equal(summary$intercept, -0.06004978, tolerance = 1e-2) # Test prediction with string label prediction <- predict(model, training) @@ -50,15 +49,17 @@ test_that("spark.svmLinear", { expect_equal(sort(as.list(take(select(prediction, "prediction"), 10))[[1]]), expected) # Test model save and load - modelPath <- tempfile(pattern = "spark-svm-linear", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - coefs <- summary(model)$coefficients - coefs2 <- summary(model2)$coefficients - expect_equal(coefs, coefs2) - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-svm-linear", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + coefs <- summary(model)$coefficients + coefs2 <- summary(model2)$coefficients + expect_equal(coefs, coefs2) + unlink(modelPath) + } # Test prediction with numeric label label <- c(0.0, 0.0, 0.0, 1.0, 1.0) @@ -128,15 +129,17 @@ test_that("spark.logit", { expect_true(all(abs(setosaCoefs - setosaCoefs) < 0.1)) # Test model save and load - modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp") - write.ml(model, 
modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - coefs <- summary(model)$coefficients - coefs2 <- summary(model2)$coefficients - expect_equal(coefs, coefs2) - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + coefs <- summary(model)$coefficients + coefs2 <- summary(model2)$coefficients + expect_equal(coefs, coefs2) + unlink(modelPath) + } # R code to reproduce the result. # nolint start @@ -243,19 +246,21 @@ test_that("spark.mlp", { expect_equal(head(mlpPredictions$prediction, 6), c("1.0", "0.0", "0.0", "0.0", "0.0", "0.0")) # Test model save/load - modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - summary2 <- summary(model2) - - expect_equal(summary2$numOfInputs, 4) - expect_equal(summary2$numOfOutputs, 3) - expect_equal(summary2$layers, c(4, 5, 4, 3)) - expect_equal(length(summary2$weights), 64) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + + expect_equal(summary2$numOfInputs, 4) + expect_equal(summary2$numOfOutputs, 3) + expect_equal(summary2$layers, c(4, 5, 4, 3)) + expect_equal(length(summary2$weights), 64) + + unlink(modelPath) + } # Test default parameter model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3)) @@ -354,16 +359,18 @@ test_that("spark.naiveBayes", { "Yes", "Yes", "No", "No")) # Test model save/load - modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp") - write.ml(m, modelPath) - expect_error(write.ml(m, modelPath)) - write.ml(m, modelPath, overwrite = TRUE) - m2 <- read.ml(modelPath) - s2 <- summary(m2) - expect_equal(s$apriori, s2$apriori) - expect_equal(s$tables, s2$tables) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp") + write.ml(m, modelPath) + expect_error(write.ml(m, modelPath)) + write.ml(m, modelPath, overwrite = TRUE) + m2 <- read.ml(modelPath) + s2 <- summary(m2) + expect_equal(s$apriori, s2$apriori) + expect_equal(s$tables, s2$tables) + + unlink(modelPath) + } # Test e1071::naiveBayes if (requireNamespace("e1071", quietly = TRUE)) { diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R index 478012e8828c..8f71de1cbc7b 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib clustering algorithms") # Tests for MLlib clustering algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) absoluteSparkPath <- function(x) { sparkHome <- sparkR.conf("spark.home") @@ -53,18 +53,20 @@ test_that("spark.bisectingKmeans", { c(0, 1, 2, 3)) # Test model save/load - modelPath <- tempfile(pattern = "spark-bisectingkmeans", fileext = ".tmp") - 
write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - summary2 <- summary(model2) - expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) - expect_equal(summary.model$coefficients, summary2$coefficients) - expect_true(!summary.model$is.loaded) - expect_true(summary2$is.loaded) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-bisectingkmeans", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) + expect_equal(summary.model$coefficients, summary2$coefficients) + expect_true(!summary.model$is.loaded) + expect_true(summary2$is.loaded) + + unlink(modelPath) + } }) test_that("spark.gaussianMixture", { @@ -125,18 +127,20 @@ test_that("spark.gaussianMixture", { expect_equal(p$prediction, c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1)) # Test model save/load - modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$lambda, stats2$lambda) - expect_equal(unlist(stats$mu), unlist(stats2$mu)) - expect_equal(unlist(stats$sigma), unlist(stats2$sigma)) - expect_equal(unlist(stats$loglik), unlist(stats2$loglik)) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$lambda, stats2$lambda) + expect_equal(unlist(stats$mu), unlist(stats2$mu)) + expect_equal(unlist(stats$sigma), unlist(stats2$sigma)) + expect_equal(unlist(stats$loglik), unlist(stats2$loglik)) + + unlink(modelPath) + } }) test_that("spark.kmeans", { @@ -171,18 +175,20 @@ test_that("spark.kmeans", { expect_true(class(summary.model$coefficients[1, ]) == "numeric") # Test model save/load - modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - summary2 <- summary(model2) - expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) - expect_equal(summary.model$coefficients, summary2$coefficients) - expect_true(!summary.model$is.loaded) - expect_true(summary2$is.loaded) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) + expect_equal(summary.model$coefficients, summary2$coefficients) + expect_true(!summary.model$is.loaded) + expect_true(summary2$is.loaded) + + unlink(modelPath) + } # Test Kmeans on dataset that is sensitive to seed value col1 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) @@ -236,22 +242,24 @@ test_that("spark.lda with libsvm", { 
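# --- Schematic of the wrapper now applied to the model save/load blocks in
# --- these tests: persistence round-trips are skipped on CRAN and on Windows
# --- machines without HADOOP_HOME. `model` stands for any fitted ML model;
# --- the concrete assertions vary per test.
if (not_cran_or_windows_with_hadoop()) {
  modelPath <- tempfile(pattern = "spark-model", fileext = ".tmp")
  write.ml(model, modelPath)
  expect_error(write.ml(model, modelPath))   # refuses to overwrite by default
  write.ml(model, modelPath, overwrite = TRUE)
  model2 <- read.ml(modelPath)
  # ... model-specific checks comparing summary(model2) with summary(model) ...
  unlink(modelPath)
}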
expect_true(logPrior <= 0 & !is.na(logPrior)) # Test model save/load - modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - - expect_true(stats2$isDistributed) - expect_equal(logLikelihood, stats2$logLikelihood) - expect_equal(logPerplexity, stats2$logPerplexity) - expect_equal(vocabSize, stats2$vocabSize) - expect_equal(vocabulary, stats2$vocabulary) - expect_equal(trainingLogLikelihood, stats2$trainingLogLikelihood) - expect_equal(logPrior, stats2$logPrior) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + + expect_true(stats2$isDistributed) + expect_equal(logLikelihood, stats2$logLikelihood) + expect_equal(logPerplexity, stats2$logPerplexity) + expect_equal(vocabSize, stats2$vocabSize) + expect_equal(vocabulary, stats2$vocabulary) + expect_equal(trainingLogLikelihood, stats2$trainingLogLikelihood) + expect_equal(logPrior, stats2$logPrior) + + unlink(modelPath) + } }) test_that("spark.lda with text input", { diff --git a/R/pkg/inst/tests/testthat/test_mllib_fpm.R b/R/pkg/inst/tests/testthat/test_mllib_fpm.R index c38f1133897d..4e10ca1e4f50 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_fpm.R +++ b/R/pkg/inst/tests/testthat/test_mllib_fpm.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib frequent pattern mining") # Tests for MLlib frequent pattern mining algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("spark.fpGrowth", { data <- selectExpr(createDataFrame(data.frame(items = c( @@ -62,15 +62,17 @@ test_that("spark.fpGrowth", { expect_equivalent(expected_predictions, collect(predict(model, new_data))) - modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp") - write.ml(model, modelPath, overwrite = TRUE) - loaded_model <- read.ml(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp") + write.ml(model, modelPath, overwrite = TRUE) + loaded_model <- read.ml(modelPath) - expect_equivalent( - itemsets, - collect(spark.freqItemsets(loaded_model))) + expect_equivalent( + itemsets, + collect(spark.freqItemsets(loaded_model))) - unlink(modelPath) + unlink(modelPath) + } model_without_numpartitions <- spark.fpGrowth(data, minSupport = 0.3, minConfidence = 0.8) expect_equal( diff --git a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R b/R/pkg/inst/tests/testthat/test_mllib_recommendation.R index 6b1040db9305..cc8064f88d27 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R +++ b/R/pkg/inst/tests/testthat/test_mllib_recommendation.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib recommendation algorithms") # Tests for MLlib recommendation algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("spark.als", { data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), @@ -37,29 +37,31 @@ test_that("spark.als", { tolerance = 1e-4) # Test model save/load - modelPath <- tempfile(pattern = 
"spark-als", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats2$rating, "score") - userFactors <- collect(stats$userFactors) - itemFactors <- collect(stats$itemFactors) - userFactors2 <- collect(stats2$userFactors) - itemFactors2 <- collect(stats2$itemFactors) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats2$rating, "score") + userFactors <- collect(stats$userFactors) + itemFactors <- collect(stats$itemFactors) + userFactors2 <- collect(stats2$userFactors) + itemFactors2 <- collect(stats2$itemFactors) - orderUser <- order(userFactors$id) - orderUser2 <- order(userFactors2$id) - expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2]) - expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2]) + orderUser <- order(userFactors$id) + orderUser2 <- order(userFactors2$id) + expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2]) + expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2]) - orderItem <- order(itemFactors$id) - orderItem2 <- order(itemFactors2$id) - expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2]) - expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2]) + orderItem <- order(itemFactors$id) + orderItem2 <- order(itemFactors2$id) + expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2]) + expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2]) - unlink(modelPath) + unlink(modelPath) + } }) sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/inst/tests/testthat/test_mllib_regression.R index 58924f952c6b..b05fdd350ca2 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_regression.R +++ b/R/pkg/inst/tests/testthat/test_mllib_regression.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib regression algorithms, except for tree-based algorithms") # Tests for MLlib regression algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("formula of spark.glm", { skip_on_cran() @@ -401,14 +401,16 @@ test_that("spark.isoreg", { expect_equal(predict_result$prediction, c(7.0, 7.0, 6.0, 5.5, 5.0, 4.0, 1.0)) # Test model save/load - modelPath <- tempfile(pattern = "spark-isoreg", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - expect_equal(result, summary(model2)) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-isoreg", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + expect_equal(result, summary(model2)) + + unlink(modelPath) + } }) test_that("spark.survreg", { @@ -450,17 +452,19 @@ test_that("spark.survreg", { 2.390146, 2.891269, 2.891269), tolerance = 1e-4) # Test model save/load - modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp") - write.ml(model, modelPath) 
- expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - coefs2 <- as.vector(stats2$coefficients[, 1]) - expect_equal(coefs, coefs2) - expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients)) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + coefs2 <- as.vector(stats2$coefficients[, 1]) + expect_equal(coefs, coefs2) + expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients)) + + unlink(modelPath) + } # Test survival::survreg if (requireNamespace("survival", quietly = TRUE)) { diff --git a/R/pkg/inst/tests/testthat/test_mllib_stat.R b/R/pkg/inst/tests/testthat/test_mllib_stat.R index beb148e7702f..1600833a5d03 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_stat.R +++ b/R/pkg/inst/tests/testthat/test_mllib_stat.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib statistics algorithms") # Tests for MLlib statistics algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("spark.kstest", { data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25, -1, -0.5)) diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/inst/tests/testthat/test_mllib_tree.R index e0802a9b02d1..5fd6a38ecb4f 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_tree.R +++ b/R/pkg/inst/tests/testthat/test_mllib_tree.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib tree-based algorithms") # Tests for MLlib tree-based algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) absoluteSparkPath <- function(x) { sparkHome <- sparkR.conf("spark.home") @@ -44,21 +44,23 @@ test_that("spark.gbt", { expect_equal(stats$numFeatures, 6) expect_equal(length(stats$treeWeights), 20) - modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$formula, stats2$formula) - expect_equal(stats$numFeatures, stats2$numFeatures) - expect_equal(stats$features, stats2$features) - expect_equal(stats$featureImportances, stats2$featureImportances) - expect_equal(stats$maxDepth, stats2$maxDepth) - expect_equal(stats$numTrees, stats2$numTrees) - expect_equal(stats$treeWeights, stats2$treeWeights) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$maxDepth, stats2$maxDepth) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$treeWeights, stats2$treeWeights) - unlink(modelPath) + unlink(modelPath) + } # classification # label must be binary 
- GBTClassifier currently only supports binary classification. @@ -76,17 +78,19 @@ test_that("spark.gbt", { expect_equal(length(grep("setosa", predictions)), 50) expect_equal(length(grep("versicolor", predictions)), 50) - modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$depth, stats2$depth) - expect_equal(stats$numNodes, stats2$numNodes) - expect_equal(stats$numClasses, stats2$numClasses) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) - unlink(modelPath) + unlink(modelPath) + } iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) df <- suppressWarnings(createDataFrame(iris2)) @@ -136,7 +140,101 @@ test_that("spark.randomForest", { expect_equal(stats$numTrees, 20) expect_equal(stats$maxDepth, 5) - modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$maxDepth, stats2$maxDepth) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) + } + + # classification + data <- suppressWarnings(createDataFrame(iris)) + model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + # Test string prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) + + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) + } + + # Test numeric response variable + labelToIndex <- function(species) { + switch(as.character(species), + setosa = 0.0, + versicolor = 1.0, + virginica = 2.0 + ) + } + iris$NumericSpecies <- lapply(iris$Species, labelToIndex) + data <- suppressWarnings(createDataFrame(iris[-5])) + model <- spark.randomForest(data, 
NumericSpecies ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) + + # Test numeric prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("1.0", predictions)), 50) + expect_equal(length(grep("2.0", predictions)), 50) + + # spark.randomForest classification can work on libsvm data + data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), + source = "libsvm") + model <- spark.randomForest(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 4) +}) + +test_that("spark.decisionTree", { + # regression + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16) + + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + + stats <- summary(model) + expect_equal(stats$maxDepth, 5) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + + modelPath <- tempfile(pattern = "spark-decisionTreeRegression", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) write.ml(model, modelPath, overwrite = TRUE) @@ -146,20 +244,17 @@ test_that("spark.randomForest", { expect_equal(stats$numFeatures, stats2$numFeatures) expect_equal(stats$features, stats2$features) expect_equal(stats$featureImportances, stats2$featureImportances) - expect_equal(stats$numTrees, stats2$numTrees) expect_equal(stats$maxDepth, stats2$maxDepth) - expect_equal(stats$treeWeights, stats2$treeWeights) unlink(modelPath) # classification data <- suppressWarnings(createDataFrame(iris)) - model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification", + model <- spark.decisionTree(data, Species ~ Petal_Length + Petal_Width, "classification", maxDepth = 5, maxBins = 16) stats <- summary(model) expect_equal(stats$numFeatures, 2) - expect_equal(stats$numTrees, 20) expect_equal(stats$maxDepth, 5) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) @@ -168,7 +263,7 @@ test_that("spark.randomForest", { expect_equal(length(grep("setosa", predictions)), 50) expect_equal(length(grep("versicolor", predictions)), 50) - modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") + modelPath <- tempfile(pattern = "spark-decisionTreeClassification", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) write.ml(model, modelPath, overwrite = TRUE) @@ -190,11 +285,10 @@ test_that("spark.randomForest", { } iris$NumericSpecies <- lapply(iris$Species, labelToIndex) data <- suppressWarnings(createDataFrame(iris[-5])) - model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification", + model <- spark.decisionTree(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification", maxDepth = 5, maxBins = 16) stats <- summary(model) expect_equal(stats$numFeatures, 2) - expect_equal(stats$numTrees, 20) expect_equal(stats$maxDepth, 5) # Test numeric prediction values @@ -202,10 +296,10 @@ test_that("spark.randomForest", { expect_equal(length(grep("1.0", predictions)), 50) expect_equal(length(grep("2.0", 
predictions)), 50) - # spark.randomForest classification can work on libsvm data + # spark.decisionTree classification can work on libsvm data data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), source = "libsvm") - model <- spark.randomForest(data, label ~ features, "classification") + model <- spark.decisionTree(data, label ~ features, "classification") expect_equal(summary(model)$numFeatures, 4) }) diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R index 1f7f387de08c..52d4c93ed959 100644 --- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R +++ b/R/pkg/inst/tests/testthat/test_parallelize_collect.R @@ -33,7 +33,7 @@ numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3)) strPairs <- list(list(strList, strList), list(strList, strList)) # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Tests diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index a3b1631e1d11..fb244e1d49e2 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -18,7 +18,7 @@ context("basic RDD functions") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R index cedf4f100c6c..18320ea44b38 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/inst/tests/testthat/test_shuffle.R @@ -18,7 +18,7 @@ context("partitionBy, groupByKey, reduceByKey etc.") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 47cc34a6c5b7..9fc6e5dabecc 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -61,7 +61,11 @@ unsetHiveContext <- function() { # Tests for SparkSQL functions in SparkR filesBefore <- list.files(path = sparkRDir, all.files = TRUE) -sparkSession <- sparkR.session() +sparkSession <- if (not_cran_or_windows_with_hadoop()) { + sparkR.session(master = sparkRTestMaster) + } else { + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + } sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockLines <- c("{\"name\":\"Michael\"}", @@ -96,6 +100,10 @@ mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}} mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesMapType, mapTypeJsonPath) +if (.Platform$OS.type == "windows") { + Sys.setenv(TZ = "GMT") +} + test_that("calling sparkRSQL.init returns existing SQL context", { skip_on_cran() @@ -322,51 +330,53 @@ test_that("createDataFrame uses files for large objects", { }) test_that("read/write csv as DataFrame", { - csvPath <- tempfile(pattern = "sparkr-test", fileext 
= ".csv") - mockLinesCsv <- c("year,make,model,comment,blank", - "\"2012\",\"Tesla\",\"S\",\"No comment\",", - "1997,Ford,E350,\"Go get one now they are going fast\",", - "2015,Chevy,Volt", - "NA,Dummy,Placeholder") - writeLines(mockLinesCsv, csvPath) - - # default "header" is false, inferSchema to handle "year" as "int" - df <- read.df(csvPath, "csv", header = "true", inferSchema = "true") - expect_equal(count(df), 4) - expect_equal(columns(df), c("year", "make", "model", "comment", "blank")) - expect_equal(sort(unlist(collect(where(df, df$year == 2015)))), - sort(unlist(list(year = 2015, make = "Chevy", model = "Volt")))) - - # since "year" is "int", let's skip the NA values - withoutna <- na.omit(df, how = "any", cols = "year") - expect_equal(count(withoutna), 3) - - unlink(csvPath) - csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") - mockLinesCsv <- c("year,make,model,comment,blank", - "\"2012\",\"Tesla\",\"S\",\"No comment\",", - "1997,Ford,E350,\"Go get one now they are going fast\",", - "2015,Chevy,Volt", - "Empty,Dummy,Placeholder") - writeLines(mockLinesCsv, csvPath) - - df2 <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings = "Empty") - expect_equal(count(df2), 4) - withoutna2 <- na.omit(df2, how = "any", cols = "year") - expect_equal(count(withoutna2), 3) - expect_equal(count(where(withoutna2, withoutna2$make == "Dummy")), 0) - - # writing csv file - csvPath2 <- tempfile(pattern = "csvtest2", fileext = ".csv") - write.df(df2, path = csvPath2, "csv", header = "true") - df3 <- read.df(csvPath2, "csv", header = "true") - expect_equal(nrow(df3), nrow(df2)) - expect_equal(colnames(df3), colnames(df2)) - csv <- read.csv(file = list.files(csvPath2, pattern = "^part", full.names = T)[[1]]) - expect_equal(colnames(df3), colnames(csv)) - - unlink(csvPath) - unlink(csvPath2) + if (not_cran_or_windows_with_hadoop()) { + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") + mockLinesCsv <- c("year,make,model,comment,blank", + "\"2012\",\"Tesla\",\"S\",\"No comment\",", + "1997,Ford,E350,\"Go get one now they are going fast\",", + "2015,Chevy,Volt", + "NA,Dummy,Placeholder") + writeLines(mockLinesCsv, csvPath) + + # default "header" is false, inferSchema to handle "year" as "int" + df <- read.df(csvPath, "csv", header = "true", inferSchema = "true") + expect_equal(count(df), 4) + expect_equal(columns(df), c("year", "make", "model", "comment", "blank")) + expect_equal(sort(unlist(collect(where(df, df$year == 2015)))), + sort(unlist(list(year = 2015, make = "Chevy", model = "Volt")))) + + # since "year" is "int", let's skip the NA values + withoutna <- na.omit(df, how = "any", cols = "year") + expect_equal(count(withoutna), 3) + + unlink(csvPath) + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") + mockLinesCsv <- c("year,make,model,comment,blank", + "\"2012\",\"Tesla\",\"S\",\"No comment\",", + "1997,Ford,E350,\"Go get one now they are going fast\",", + "2015,Chevy,Volt", + "Empty,Dummy,Placeholder") + writeLines(mockLinesCsv, csvPath) + + df2 <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings = "Empty") + expect_equal(count(df2), 4) + withoutna2 <- na.omit(df2, how = "any", cols = "year") + expect_equal(count(withoutna2), 3) + expect_equal(count(where(withoutna2, withoutna2$make == "Dummy")), 0) + + # writing csv file + csvPath2 <- tempfile(pattern = "csvtest2", fileext = ".csv") + write.df(df2, path = csvPath2, "csv", header = "true") + df3 <- read.df(csvPath2, "csv", header = "true") + 
expect_equal(nrow(df3), nrow(df2)) + expect_equal(colnames(df3), colnames(df2)) + csv <- read.csv(file = list.files(csvPath2, pattern = "^part", full.names = T)[[1]]) + expect_equal(colnames(df3), colnames(csv)) + + unlink(csvPath) + unlink(csvPath2) + } }) test_that("Support other types for options", { @@ -597,48 +607,50 @@ test_that("Collect DataFrame with complex types", { }) test_that("read/write json files", { - # Test read.df - df <- read.df(jsonPath, "json") - expect_is(df, "SparkDataFrame") - expect_equal(count(df), 3) - - # Test read.df with a user defined schema - schema <- structType(structField("name", type = "string"), - structField("age", type = "double")) - - df1 <- read.df(jsonPath, "json", schema) - expect_is(df1, "SparkDataFrame") - expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) - - # Test loadDF - df2 <- loadDF(jsonPath, "json", schema) - expect_is(df2, "SparkDataFrame") - expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) - - # Test read.json - df <- read.json(jsonPath) - expect_is(df, "SparkDataFrame") - expect_equal(count(df), 3) - - # Test write.df - jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".json") - write.df(df, jsonPath2, "json", mode = "overwrite") - - # Test write.json - jsonPath3 <- tempfile(pattern = "jsonPath3", fileext = ".json") - write.json(df, jsonPath3) - - # Test read.json()/jsonFile() works with multiple input paths - jsonDF1 <- read.json(c(jsonPath2, jsonPath3)) - expect_is(jsonDF1, "SparkDataFrame") - expect_equal(count(jsonDF1), 6) - # Suppress warnings because jsonFile is deprecated - jsonDF2 <- suppressWarnings(jsonFile(c(jsonPath2, jsonPath3))) - expect_is(jsonDF2, "SparkDataFrame") - expect_equal(count(jsonDF2), 6) - - unlink(jsonPath2) - unlink(jsonPath3) + if (not_cran_or_windows_with_hadoop()) { + # Test read.df + df <- read.df(jsonPath, "json") + expect_is(df, "SparkDataFrame") + expect_equal(count(df), 3) + + # Test read.df with a user defined schema + schema <- structType(structField("name", type = "string"), + structField("age", type = "double")) + + df1 <- read.df(jsonPath, "json", schema) + expect_is(df1, "SparkDataFrame") + expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) + + # Test loadDF + df2 <- loadDF(jsonPath, "json", schema) + expect_is(df2, "SparkDataFrame") + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) + + # Test read.json + df <- read.json(jsonPath) + expect_is(df, "SparkDataFrame") + expect_equal(count(df), 3) + + # Test write.df + jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".json") + write.df(df, jsonPath2, "json", mode = "overwrite") + + # Test write.json + jsonPath3 <- tempfile(pattern = "jsonPath3", fileext = ".json") + write.json(df, jsonPath3) + + # Test read.json()/jsonFile() works with multiple input paths + jsonDF1 <- read.json(c(jsonPath2, jsonPath3)) + expect_is(jsonDF1, "SparkDataFrame") + expect_equal(count(jsonDF1), 6) + # Suppress warnings because jsonFile is deprecated + jsonDF2 <- suppressWarnings(jsonFile(c(jsonPath2, jsonPath3))) + expect_is(jsonDF2, "SparkDataFrame") + expect_equal(count(jsonDF2), 6) + + unlink(jsonPath2) + unlink(jsonPath3) + } }) test_that("read/write json files - compression option", { @@ -673,24 +685,27 @@ test_that("jsonRDD() on a RDD with json string", { }) test_that("test tableNames and tables", { + count <- count(listTables()) + df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") - expect_equal(length(tableNames()), 1) - 
expect_equal(length(tableNames("default")), 1) + expect_equal(length(tableNames()), count + 1) + expect_equal(length(tableNames("default")), count + 1) + tables <- listTables() - expect_equal(count(tables), 1) + expect_equal(count(tables), count + 1) expect_equal(count(tables()), count(tables)) expect_true("tableName" %in% colnames(tables())) expect_true(all(c("tableName", "database", "isTemporary") %in% colnames(tables()))) suppressWarnings(registerTempTable(df, "table2")) tables <- listTables() - expect_equal(count(tables), 2) + expect_equal(count(tables), count + 2) suppressWarnings(dropTempTable("table1")) expect_true(dropTempView("table2")) tables <- listTables() - expect_equal(count(tables), 0) + expect_equal(count(tables), count + 0) }) test_that( @@ -729,33 +744,35 @@ test_that("test cache, uncache and clearCache", { }) test_that("insertInto() on a registered table", { - df <- read.df(jsonPath, "json") - write.df(df, parquetPath, "parquet", "overwrite") - dfParquet <- read.df(parquetPath, "parquet") - - lines <- c("{\"name\":\"Bob\", \"age\":24}", - "{\"name\":\"James\", \"age\":35}") - jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".tmp") - parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") - writeLines(lines, jsonPath2) - df2 <- read.df(jsonPath2, "json") - write.df(df2, parquetPath2, "parquet", "overwrite") - dfParquet2 <- read.df(parquetPath2, "parquet") - - createOrReplaceTempView(dfParquet, "table1") - insertInto(dfParquet2, "table1") - expect_equal(count(sql("select * from table1")), 5) - expect_equal(first(sql("select * from table1 order by age"))$name, "Michael") - expect_true(dropTempView("table1")) - - createOrReplaceTempView(dfParquet, "table1") - insertInto(dfParquet2, "table1", overwrite = TRUE) - expect_equal(count(sql("select * from table1")), 2) - expect_equal(first(sql("select * from table1 order by age"))$name, "Bob") - expect_true(dropTempView("table1")) - - unlink(jsonPath2) - unlink(parquetPath2) + if (not_cran_or_windows_with_hadoop()) { + df <- read.df(jsonPath, "json") + write.df(df, parquetPath, "parquet", "overwrite") + dfParquet <- read.df(parquetPath, "parquet") + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".tmp") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + writeLines(lines, jsonPath2) + df2 <- read.df(jsonPath2, "json") + write.df(df2, parquetPath2, "parquet", "overwrite") + dfParquet2 <- read.df(parquetPath2, "parquet") + + createOrReplaceTempView(dfParquet, "table1") + insertInto(dfParquet2, "table1") + expect_equal(count(sql("select * from table1")), 5) + expect_equal(first(sql("select * from table1 order by age"))$name, "Michael") + expect_true(dropTempView("table1")) + + createOrReplaceTempView(dfParquet, "table1") + insertInto(dfParquet2, "table1", overwrite = TRUE) + expect_equal(count(sql("select * from table1")), 2) + expect_equal(first(sql("select * from table1 order by age"))$name, "Bob") + expect_true(dropTempView("table1")) + + unlink(jsonPath2) + unlink(parquetPath2) + } }) test_that("tableToDF() returns a new DataFrame", { @@ -947,14 +964,16 @@ test_that("cache(), storageLevel(), persist(), and unpersist() on a DataFrame", }) test_that("setCheckpointDir(), checkpoint() on a DataFrame", { - checkpointDir <- file.path(tempdir(), "cproot") - expect_true(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) - - setCheckpointDir(checkpointDir) - df <- read.json(jsonPath) 
- df <- checkpoint(df) - expect_is(df, "SparkDataFrame") - expect_false(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) + if (not_cran_or_windows_with_hadoop()) { + checkpointDir <- file.path(tempdir(), "cproot") + expect_true(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) + + setCheckpointDir(checkpointDir) + df <- read.json(jsonPath) + df <- checkpoint(df) + expect_is(df, "SparkDataFrame") + expect_false(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) + } }) test_that("schema(), dtypes(), columns(), names() return the correct values/format", { @@ -1223,6 +1242,16 @@ test_that("select with column", { expect_equal(columns(df4), c("name", "age")) expect_equal(count(df4), 3) + # Test select with alias + df5 <- alias(df, "table") + + expect_equal(columns(select(df5, column("table.name"))), "name") + expect_equal(columns(select(df5, "table.name")), "name") + + # Test that stats::alias is not masked + expect_is(alias(aov(yield ~ block + N * P * K, npk)), "listof") + + expect_error(select(df, c("name", "age"), "name"), "To select multiple columns, use a character vector or list for col") }) @@ -1312,45 +1341,47 @@ test_that("column calculation", { }) test_that("test HiveContext", { - setHiveContext(sc) - - schema <- structType(structField("name", "string"), structField("age", "integer"), - structField("height", "float")) - createTable("people", source = "json", schema = schema) - df <- read.df(jsonPathNa, "json", schema) - insertInto(df, "people") - expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) - sql("DROP TABLE people") - - df <- createTable("json", jsonPath, "json") - expect_is(df, "SparkDataFrame") - expect_equal(count(df), 3) - df2 <- sql("select * from json") - expect_is(df2, "SparkDataFrame") - expect_equal(count(df2), 3) - - jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - saveAsTable(df, "json2", "json", "append", path = jsonPath2) - df3 <- sql("select * from json2") - expect_is(df3, "SparkDataFrame") - expect_equal(count(df3), 3) - unlink(jsonPath2) - - hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - saveAsTable(df, "hivetestbl", path = hivetestDataPath) - df4 <- sql("select * from hivetestbl") - expect_is(df4, "SparkDataFrame") - expect_equal(count(df4), 3) - unlink(hivetestDataPath) - - parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath) - df5 <- sql("select * from parquetest") - expect_is(df5, "SparkDataFrame") - expect_equal(count(df5), 3) - unlink(parquetDataPath) - - unsetHiveContext() + if (not_cran_or_windows_with_hadoop()) { + setHiveContext(sc) + + schema <- structType(structField("name", "string"), structField("age", "integer"), + structField("height", "float")) + createTable("people", source = "json", schema = schema) + df <- read.df(jsonPathNa, "json", schema) + insertInto(df, "people") + expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) + sql("DROP TABLE people") + + df <- createTable("json", jsonPath, "json") + expect_is(df, "SparkDataFrame") + expect_equal(count(df), 3) + df2 <- sql("select * from json") + expect_is(df2, "SparkDataFrame") + expect_equal(count(df2), 3) + + jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + saveAsTable(df, "json2", "json", "append", path = jsonPath2) + df3 <- sql("select * from json2") + expect_is(df3, "SparkDataFrame") + expect_equal(count(df3), 3) + 
unlink(jsonPath2) + + hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + saveAsTable(df, "hivetestbl", path = hivetestDataPath) + df4 <- sql("select * from hivetestbl") + expect_is(df4, "SparkDataFrame") + expect_equal(count(df4), 3) + unlink(hivetestDataPath) + + parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath) + df5 <- sql("select * from parquetest") + expect_is(df5, "SparkDataFrame") + expect_equal(count(df5), 3) + unlink(parquetDataPath) + + unsetHiveContext() + } }) test_that("column operators", { @@ -2199,6 +2230,11 @@ test_that("join(), crossJoin() and merge() on a DataFrame", { explain(join(df1, hint(df2, "broadcast"), df1$id == df2$id)) ) expect_true(any(grepl("BroadcastHashJoin", execution_plan_hint))) + + execution_plan_broadcast <- capture.output( + explain(join(df1, broadcast(df2), df1$id == df2$id)) + ) + expect_true(any(grepl("BroadcastHashJoin", execution_plan_broadcast))) }) test_that("toJSON() on DataFrame", { @@ -2398,34 +2434,36 @@ test_that("read/write ORC files - compression option", { }) test_that("read/write Parquet files", { - df <- read.df(jsonPath, "json") - # Test write.df and read.df - write.df(df, parquetPath, "parquet", mode = "overwrite") - df2 <- read.df(parquetPath, "parquet") - expect_is(df2, "SparkDataFrame") - expect_equal(count(df2), 3) - - # Test write.parquet/saveAsParquetFile and read.parquet/parquetFile - parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") - write.parquet(df, parquetPath2) - parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") - suppressWarnings(saveAsParquetFile(df, parquetPath3)) - parquetDF <- read.parquet(c(parquetPath2, parquetPath3)) - expect_is(parquetDF, "SparkDataFrame") - expect_equal(count(parquetDF), count(df) * 2) - parquetDF2 <- suppressWarnings(parquetFile(parquetPath2, parquetPath3)) - expect_is(parquetDF2, "SparkDataFrame") - expect_equal(count(parquetDF2), count(df) * 2) - - # Test if varargs works with variables - saveMode <- "overwrite" - mergeSchema <- "true" - parquetPath4 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") - write.df(df, parquetPath3, "parquet", mode = saveMode, mergeSchema = mergeSchema) - - unlink(parquetPath2) - unlink(parquetPath3) - unlink(parquetPath4) + if (not_cran_or_windows_with_hadoop()) { + df <- read.df(jsonPath, "json") + # Test write.df and read.df + write.df(df, parquetPath, "parquet", mode = "overwrite") + df2 <- read.df(parquetPath, "parquet") + expect_is(df2, "SparkDataFrame") + expect_equal(count(df2), 3) + + # Test write.parquet/saveAsParquetFile and read.parquet/parquetFile + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + write.parquet(df, parquetPath2) + parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + suppressWarnings(saveAsParquetFile(df, parquetPath3)) + parquetDF <- read.parquet(c(parquetPath2, parquetPath3)) + expect_is(parquetDF, "SparkDataFrame") + expect_equal(count(parquetDF), count(df) * 2) + parquetDF2 <- suppressWarnings(parquetFile(parquetPath2, parquetPath3)) + expect_is(parquetDF2, "SparkDataFrame") + expect_equal(count(parquetDF2), count(df) * 2) + + # Test if varargs works with variables + saveMode <- "overwrite" + mergeSchema <- "true" + parquetPath4 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + write.df(df, parquetPath3, "parquet", mode = saveMode, mergeSchema = mergeSchema) + + unlink(parquetPath2) + 
unlink(parquetPath3) + unlink(parquetPath4) + } }) test_that("read/write Parquet files - compression option/mode", { @@ -3387,7 +3425,7 @@ compare_list <- function(list1, list2) { # This should always be the **very last test** in this test file. test_that("No extra files are created in SPARK_HOME by starting session and making calls", { - skip_on_cran() + skip_on_cran() # skip because when run from R CMD check SPARK_HOME is not the current directory # Check that it is not creating any extra file. # Does not check the tempdir which would be cleaned up after. diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/inst/tests/testthat/test_streaming.R index 91df7ac6f984..b20b4312fbaa 100644 --- a/R/pkg/inst/tests/testthat/test_streaming.R +++ b/R/pkg/inst/tests/testthat/test_streaming.R @@ -21,7 +21,7 @@ context("Structured Streaming") # Tests for Structured Streaming functions in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) jsonSubDir <- file.path("sparkr-test", "json", "") if (.Platform$OS.type == "windows") { diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/inst/tests/testthat/test_take.R index e2130eaac78d..c00723ba31f4 100644 --- a/R/pkg/inst/tests/testthat/test_take.R +++ b/R/pkg/inst/tests/testthat/test_take.R @@ -30,7 +30,7 @@ strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", "raising me. But they're both dead now. I didn't kill them. Honest.") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("take() gives back the original elements in correct count and order", { diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R index 28b7e8e3183f..e8a961cb3e87 100644 --- a/R/pkg/inst/tests/testthat/test_textFile.R +++ b/R/pkg/inst/tests/testthat/test_textFile.R @@ -18,7 +18,7 @@ context("the textFile() function") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockFile <- c("Spark is pretty.", "Spark is awesome.") diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 4a01e875405f..02691f0f6431 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -18,7 +18,7 @@ context("functions in utils.R") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("convertJListToRList() gives back (deserializes) the original JLists @@ -136,7 +136,7 @@ test_that("cleanClosure on R functions", { # Test for broadcast variables. 
a <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) - aBroadcast <- broadcast(sc, a) + aBroadcast <- broadcastRDD(sc, a) normMultiply <- function(x) { norm(aBroadcast$value) * x } newnormMultiply <- SparkR:::cleanClosure(normMultiply) env <- environment(newnormMultiply) diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index 29812f872c78..9c6cba535d11 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -31,4 +31,9 @@ sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") invisible(lapply(sparkRWhitelistSQLDirs, function(x) { unlink(file.path(sparkRDir, x), recursive = TRUE, force = TRUE)})) +sparkRTestMaster <- "local[1]" +if (identical(Sys.getenv("NOT_CRAN"), "true")) { + sparkRTestMaster <- "" +} + test_package("SparkR") diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index d38ec4f1b6f3..13a399165c8b 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -46,8 +46,9 @@ We use default settings in which it runs in local mode. It auto downloads Spark ```{r, include=FALSE} install.spark() +sparkR.session(master = "local[1]") ``` -```{r, message=FALSE, results="hide"} +```{r, eval=FALSE} sparkR.session() ``` @@ -65,7 +66,7 @@ We can view the first few rows of the `SparkDataFrame` by `head` or `showDF` fun head(carsDF) ``` -Common data processing operations such as `filter`, `select` are supported on the `SparkDataFrame`. +Common data processing operations such as `filter` and `select` are supported on the `SparkDataFrame`. ```{r} carsSubDF <- select(carsDF, "model", "mpg", "hp") carsSubDF <- filter(carsSubDF, carsSubDF$hp >= 200) @@ -379,7 +380,7 @@ out <- dapply(carsSubDF, function(x) { x <- cbind(x, x$mpg * 1.61) }, schema) head(collect(out)) ``` -Like `dapply`, apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. +Like `dapply`, `dapplyCollect` can apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of the function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of the UDF on all partitions cannot be pulled into the driver's memory. ```{r} out <- dapplyCollect( @@ -405,7 +406,7 @@ result <- gapply( head(arrange(result, "max_mpg", decreasing = TRUE)) ``` -Like gapply, `gapplyCollect` applies a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. +Like `gapply`, `gapplyCollect` can apply a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of the UDF on all partitions cannot be pulled into the driver's memory. ```{r} result <- gapplyCollect( @@ -458,20 +459,20 @@ options(ops) ### SQL Queries -A `SparkDataFrame` can also be registered as a temporary view in Spark SQL and that allows you to run SQL queries over its data. 
The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. +A `SparkDataFrame` can also be registered as a temporary view in Spark SQL so that one can run SQL queries over its data. The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. ```{r} people <- read.df(paste0(sparkR.conf("spark.home"), "/examples/src/main/resources/people.json"), "json") ``` -Register this SparkDataFrame as a temporary view. +Register this `SparkDataFrame` as a temporary view. ```{r} createOrReplaceTempView(people, "people") ``` -SQL statements can be run by using the sql method. +SQL statements can be run using the sql method. ```{r} teenagers <- sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") head(teenagers) @@ -780,7 +781,7 @@ head(predict(isoregModel, newDF)) `spark.gbt` fits a [gradient-boosted tree](https://en.wikipedia.org/wiki/Gradient_boosting) classification or regression model on a `SparkDataFrame`. Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. -Similar to the random forest example above, we use the `longley` dataset to train a gradient-boosted tree and make predictions: +We use the `longley` dataset to train a gradient-boosted tree and make predictions: ```{r, warning=FALSE} df <- createDataFrame(longley) @@ -820,7 +821,7 @@ head(select(fitted, "Class", "prediction")) `spark.gaussianMixture` fits multivariate [Gaussian Mixture Model](https://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) (GMM) against a `SparkDataFrame`. [Expectation-Maximization](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) (EM) is used to approximate the maximum likelihood estimator (MLE) of the model. -We use a simulated example to demostrate the usage. +We use a simulated example to demonstrate the usage. ```{r} X1 <- data.frame(V1 = rnorm(4), V2 = rnorm(4)) X2 <- data.frame(V1 = rnorm(6, 3), V2 = rnorm(6, 4)) @@ -851,9 +852,9 @@ head(select(kmeansPredictions, "model", "mpg", "hp", "wt", "prediction"), n = 20 * Topics and documents both exist in a feature space, where feature vectors are vectors of word counts (bag of words). -* Rather than estimating a clustering using a traditional distance, LDA uses a function based on a statistical model of how text documents are generated. +* Rather than clustering using a traditional distance, LDA uses a function based on a statistical model of how text documents are generated. -To use LDA, we need to specify a `features` column in `data` where each entry represents a document. There are two type options for the column: +To use LDA, we need to specify a `features` column in `data` where each entry represents a document. There are two options for the column: * character string: This can be a string of the whole document. It will be parsed automatically. Additional stop words can be added in `customizedStopWords`. @@ -901,7 +902,7 @@ perplexity `spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](http://dl.acm.org/citation.cfm?id=1608614). -There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, `nonnegative`. For a complete list, refer to the help file. 
+There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, and `nonnegative`. For a complete list, refer to the help file. ```{r, eval=FALSE} ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), @@ -981,7 +982,7 @@ testSummary ### Model Persistence -The following example shows how to save/load an ML model by SparkR. +The following example shows how to save/load an ML model in SparkR. ```{r} t <- as.data.frame(Titanic) training <- createDataFrame(t) @@ -1079,19 +1080,19 @@ There are three main object classes in SparkR you may be working with. + `sdf` stores a reference to the corresponding Spark Dataset in the Spark JVM backend. + `env` saves the meta-information of the object such as `isCached`. -It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate `SparkDataFrame` by numerous data processing functions and feed that into machine learning algorithms. + It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate `SparkDataFrame` by numerous data processing functions and feed that into machine learning algorithms. -* `Column`: an S4 class representing column of `SparkDataFrame`. The slot `jc` saves a reference to the corresponding Column object in the Spark JVM backend. +* `Column`: an S4 class representing a column of `SparkDataFrame`. The slot `jc` saves a reference to the corresponding `Column` object in the Spark JVM backend. -It can be obtained from a `SparkDataFrame` by `$` operator, `df$col`. More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, with aggregation functions to compute aggregate statistics for each group. + It can be obtained from a `SparkDataFrame` by `$` operator, e.g., `df$col`. More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, with aggregation functions to compute aggregate statistics for each group. -* `GroupedData`: an S4 class representing grouped data created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a RelationalGroupedDataset object in the backend. +* `GroupedData`: an S4 class representing grouped data created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a `RelationalGroupedDataset` object in the backend. -This is often an intermediate object with group information and followed up by aggregation operations. + This is often an intermediate object with group information and followed up by aggregation operations. ### Architecture -A complete description of architecture can be seen in reference, in particular the paper *SparkR: Scaling R Programs with Spark*. +A complete description of architecture can be seen in the references, in particular the paper *SparkR: Scaling R Programs with Spark*. Under the hood of SparkR is Spark SQL engine. This avoids the overheads of running interpreted R code, and the optimized SQL execution engine in Spark uses structural information about data and computation flow to perform a bunch of optimizations to speed up the computation. 
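As a quick aside to the vignette text above on the three object classes, their relationship can be seen in a minimal SparkR sketch. This block is not part of the patch; it assumes a local session, and the `mtcars` data and the `local[1]` master are illustrative choices only.

```r
library(SparkR)
sparkR.session(master = "local[1]")      # assumed local session, for illustration only

df <- createDataFrame(mtcars)            # SparkDataFrame: wraps a Dataset in the JVM backend
hpCol <- df$hp                           # Column: obtained from the SparkDataFrame via `$`
fast <- filter(df, hpCol >= 150)         # Column expressions select rows
byCyl <- groupBy(fast, "cyl")            # GroupedData: an intermediate grouped object
head(agg(byCyl, avg(fast$mpg)))          # aggregation yields a SparkDataFrame again

sparkR.session.stop()
```

Here `groupBy` produces the intermediate `GroupedData` and `agg` collapses it back into a `SparkDataFrame`, matching the class descriptions in the vignette.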
diff --git a/appveyor.yml b/appveyor.yml index 4d31af70f056..58c2e98289e9 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -48,6 +48,9 @@ install: build_script: - cmd: mvn -DskipTests -Psparkr -Phive -Phive-thriftserver package +environment: + NOT_CRAN: true + test_script: - cmd: .\bin\spark-submit2.cmd --driver-java-options "-Dlog4j.configuration=file:///%CD:\=/%/R/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R @@ -56,4 +59,3 @@ notifications: on_build_success: false on_build_failure: false on_build_status_changed: false - diff --git a/assembly/pom.xml b/assembly/pom.xml index 742a4a1531e7..464af16e46f6 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -226,5 +226,19 @@ provided + + + + hadoop-cloud + + + org.apache.spark + spark-hadoop-cloud_${scala.binary.version} + ${project.version} + + + diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index a8488d8d1b70..7e24315c5f39 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -87,6 +87,9 @@ *:* + + org.scala-lang:scala-library + @@ -98,7 +101,7 @@ - + com.fasterxml.jackson ${spark.shade.packageName}.com.fasterxml.jackson diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 4ab5b6889c21..aca6fca00c48 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -48,7 +48,7 @@ public final class Platform { boolean _unaligned; String arch = System.getProperty("os.arch", ""); if (arch.equals("ppc64le") || arch.equals("ppc64")) { - // Since java.nio.Bits.unaligned() doesn't return true on ppc (See JDK-8165231), but + // Since java.nio.Bits.unaligned() doesn't return true on ppc (See JDK-8165231), but // ppc64 and ppc64le support it _unaligned = true; } else { diff --git a/conf/docker.properties.template b/conf/docker.properties.template index 55cb094b4af4..2ecb4f1464a4 100644 --- a/conf/docker.properties.template +++ b/conf/docker.properties.template @@ -15,6 +15,6 @@ # limitations under the License. # -spark.mesos.executor.docker.image: +spark.mesos.executor.docker.image: spark.mesos.executor.docker.volumes: /usr/local/lib:/host/usr/local/lib:ro spark.mesos.executor.home: /opt/spark diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 94bd2c477a35..b7c985ace69c 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -34,7 +34,6 @@ # Options read in YARN client mode # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files -# - SPARK_EXECUTOR_INSTANCES, Number of executors to start (Default: 2) # - SPARK_EXECUTOR_CORES, Number of cores for the executors (Default: 1). # - SPARK_EXECUTOR_MEMORY, Memory per Executor (e.g. 1000M, 2G) (Default: 1G) # - SPARK_DRIVER_MEMORY, Memory for Driver (e.g. 
1000M, 2G) (Default: 1G) diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index cb9922d23c44..d430d8c5fb35 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -26,7 +26,6 @@ function getThreadDumpEnabled() { } function formatStatus(status, type) { - if (type !== 'display') return status; if (status) { return "Active" } else { @@ -417,7 +416,6 @@ $(document).ready(function () { }, {data: 'hostPort'}, {data: 'isActive', render: function (data, type, row) { - if (type !== 'display') return data; if (row.isBlacklisted) return "Blacklisted"; else return formatStatus (data, type); } @@ -492,24 +490,20 @@ $(document).ready(function () { {data: 'totalInputBytes', render: formatBytes}, {data: 'totalShuffleRead', render: formatBytes}, {data: 'totalShuffleWrite', render: formatBytes}, - {data: 'executorLogs', render: formatLogsCells}, + {name: 'executorLogsCol', data: 'executorLogs', render: formatLogsCells}, { + name: 'threadDumpCol', data: 'id', render: function (data, type) { return type === 'display' ? ("Thread Dump" ) : data; } } ], - "columnDefs": [ - { - "targets": [ 16 ], - "visible": getThreadDumpEnabled() - } - ], "order": [[0, "asc"]] }; var dt = $(selector).DataTable(conf); - dt.column(15).visible(logsExist(response)); + dt.column('executorLogsCol:name').visible(logsExist(response)); + dt.column('threadDumpCol:name').visible(getThreadDumpEnabled()); $('#active-executors [data-toggle="tooltip"]').tooltip(); var sumSelector = "#summary-execs-table"; diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js index ff241470f32d..9960d5c34d1f 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js +++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js @@ -207,8 +207,8 @@ sorttable = { hasInputs = (typeof node.getElementsByTagName == 'function') && node.getElementsByTagName('input').length; - - if (node.getAttribute("sorttable_customkey") != null) { + + if (node.nodeType == 1 && node.getAttribute("sorttable_customkey") != null) { return node.getAttribute("sorttable_customkey"); } else if (typeof node.textContent != 'undefined' && !hasInputs) { diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index a50600f1488c..089969398801 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -261,7 +261,7 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S private def getImpl(timeout: Duration): T = { // This will throw TimeoutException on timeout: - Await.ready(futureAction, timeout) + ThreadUtils.awaitReady(futureAction, timeout) futureAction.value.get match { case scala.util.Success(value) => converter(value) case scala.util.Failure(exception) => diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 2a2ce0504dbb..956724b14bba 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -579,7 +579,9 @@ private[spark] object SparkConf extends Logging { "are no longer accepted. 
To specify the equivalent now, one may use '64k'."), DeprecatedConfig("spark.rpc", "2.0", "Not used any more."), DeprecatedConfig("spark.scheduler.executorTaskBlacklistTime", "2.1.0", - "Please use the new blacklisting options, spark.blacklist.*") + "Please use the new blacklisting options, spark.blacklist.*"), + DeprecatedConfig("spark.yarn.am.port", "2.0.0", "Not used any more"), + DeprecatedConfig("spark.executor.port", "2.0.0", "Not used any more") ) Map(configs.map { cfg => (cfg.key -> cfg) } : _*) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index f4a59f069a5f..3196c1ece15e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -177,7 +177,7 @@ object SparkEnv extends Logging { SparkContext.DRIVER_IDENTIFIER, bindAddress, advertiseAddress, - port, + Option(port), isLocal, numCores, ioEncryptionKey, @@ -194,7 +194,6 @@ object SparkEnv extends Logging { conf: SparkConf, executorId: String, hostname: String, - port: Int, numCores: Int, ioEncryptionKey: Option[Array[Byte]], isLocal: Boolean): SparkEnv = { @@ -203,7 +202,7 @@ object SparkEnv extends Logging { executorId, hostname, hostname, - port, + None, isLocal, numCores, ioEncryptionKey @@ -220,7 +219,7 @@ object SparkEnv extends Logging { executorId: String, bindAddress: String, advertiseAddress: String, - port: Int, + port: Option[Int], isLocal: Boolean, numUsableCores: Int, ioEncryptionKey: Option[Array[Byte]], @@ -243,17 +242,12 @@ object SparkEnv extends Logging { } val systemName = if (isDriver) driverSystemName else executorSystemName - val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port, conf, + val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port.getOrElse(-1), conf, securityManager, clientMode = !isDriver) // Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied. - // In the non-driver case, the RPC env's address may be null since it may not be listening - // for incoming connections. if (isDriver) { conf.set("spark.driver.port", rpcEnv.address.port.toString) - } else if (rpcEnv.address != null) { - conf.set("spark.executor.port", rpcEnv.address.port.toString) - logInfo(s"Setting spark.executor.port to: ${rpcEnv.address.port.toString}") } // Create an instance of the class with the given name, possibly initializing it with our conf diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 8cd1d1c96aa0..01d8973e1bb0 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -110,10 +110,10 @@ private[spark] class TaskContextImpl( /** Marks the task as completed and triggers the completion listeners. 
*/ @GuardedBy("this") - private[spark] def markTaskCompleted(): Unit = synchronized { + private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = synchronized { if (completed) return completed = true - invokeListeners(onCompleteCallbacks, "TaskCompletionListener", None) { + invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) { _.onTaskCompletion(this) } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index b0dd2fc187ba..fb0405b1a69c 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -879,7 +879,7 @@ private[spark] class PythonAccumulatorV2( private val serverPort: Int) extends CollectionAccumulator[Array[Byte]] { - Utils.checkHost(serverHost, "Expected hostname") + Utils.checkHost(serverHost) val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index ac09c6c497f8..c1a91c27eef2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -24,7 +24,7 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.RecoveryState.MasterState import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} import org.apache.spark.util.Utils private[deploy] sealed trait DeployMessage extends Serializable @@ -34,6 +34,16 @@ private[deploy] object DeployMessages { // Worker to Master + /** + * @param id the worker id + * @param host the worker host + * @param port the worker post + * @param worker the worker endpoint ref + * @param cores the core number of worker + * @param memory the memory size of worker + * @param workerWebUiUrl the worker Web UI address + * @param masterAddress the master address used by the worker to connect + */ case class RegisterWorker( id: String, host: String, @@ -41,9 +51,10 @@ private[deploy] object DeployMessages { worker: RpcEndpointRef, cores: Int, memory: Int, - workerWebUiUrl: String) + workerWebUiUrl: String, + masterAddress: RpcAddress) extends DeployMessage { - Utils.checkHost(host, "Required hostname") + Utils.checkHost(host) assert (port > 0) } @@ -80,8 +91,16 @@ private[deploy] object DeployMessages { sealed trait RegisterWorkerResponse - case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage - with RegisterWorkerResponse + /** + * @param master the master ref + * @param masterWebUiUrl the master Web UI address + * @param masterAddress the master address used by the worker to connect. It should be + * [[RegisterWorker.masterAddress]]. 
+ */ + case class RegisteredWorker( + master: RpcEndpointRef, + masterWebUiUrl: String, + masterAddress: RpcAddress) extends DeployMessage with RegisterWorkerResponse case class RegisterWorkerFailed(message: String) extends DeployMessage with RegisterWorkerResponse @@ -131,7 +150,7 @@ private[deploy] object DeployMessages { // TODO(matei): replace hostPort with host case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { - Utils.checkHostPort(hostPort, "Required hostport") + Utils.checkHostPort(hostPort) } case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String], @@ -183,7 +202,7 @@ private[deploy] object DeployMessages { completedDrivers: Array[DriverInfo], status: MasterState) { - Utils.checkHost(host, "Required hostname") + Utils.checkHost(host) assert (port > 0) def uri: String = "spark://" + host + ":" + port @@ -201,7 +220,7 @@ private[deploy] object DeployMessages { drivers: List[DriverRunner], finishedDrivers: List[DriverRunner], masterUrl: String, cores: Int, memory: Int, coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) { - Utils.checkHost(host, "Required hostname") + Utils.checkHost(host) assert (port > 0) } diff --git a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala index 050778a895c0..7d356e8fc1c0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala @@ -92,6 +92,9 @@ private[deploy] object RPackageUtils extends Logging { * Exposed for testing. */ private[deploy] def checkManifestForR(jar: JarFile): Boolean = { + if (jar.getManifest == null) { + return false + } val manifest = jar.getManifest.getMainAttributes manifest.getValue(hasRPackage) != null && manifest.getValue(hasRPackage).trim == "true" } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 6d8758a3d3b1..5cb48ca3e60b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -30,7 +30,8 @@ private[spark] case class ApplicationAttemptInfo( endTime: Long, lastUpdated: Long, sparkUser: String, - completed: Boolean = false) + completed: Boolean = false, + appSparkVersion: String) private[spark] case class ApplicationHistoryInfo( id: String, diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index f4235df24512..d05ca142b618 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -248,7 +248,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val conf = this.conf.clone() val appSecManager = new SecurityManager(conf) SparkUI.createHistoryUI(conf, replayBus, appSecManager, appInfo.name, - HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) + HistoryServer.getAttemptURI(appId, attempt.attemptId), + attempt.startTime) // Do not call ui.bind() to avoid creating a new server for each application } @@ -257,6 +258,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val appListener = replay(fileStatus, 
isApplicationCompleted(fileStatus), replayBus) if (appListener.appId.isDefined) { + ui.appSparkVersion = appListener.appSparkVersion.getOrElse("") ui.getSecurityManager.setAcls(HISTORY_UI_ACLS_ENABLE) // make sure to set admin acls before view acls so they are properly picked up val adminAcls = HISTORY_UI_ADMIN_ACLS + "," + appListener.adminAcls.getOrElse("") @@ -443,7 +445,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val newAttempts = try { val eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || - eventString.startsWith(APPL_END_EVENT_PREFIX) + eventString.startsWith(APPL_END_EVENT_PREFIX) || + eventString.startsWith(LOG_START_EVENT_PREFIX) } val logPath = fileStatus.getPath() @@ -469,7 +472,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) lastUpdated, appListener.sparkUser.getOrElse(NOT_STARTED), appCompleted, - fileStatus.getLen() + fileStatus.getLen(), + appListener.appSparkVersion.getOrElse("") ) fileToAppInfo(logPath) = attemptInfo logDebug(s"Application log ${attemptInfo.logPath} loaded successfully: $attemptInfo") @@ -735,6 +739,8 @@ private[history] object FsHistoryProvider { private val APPL_START_EVENT_PREFIX = "{\"Event\":\"SparkListenerApplicationStart\"" private val APPL_END_EVENT_PREFIX = "{\"Event\":\"SparkListenerApplicationEnd\"" + + private val LOG_START_EVENT_PREFIX = "{\"Event\":\"SparkListenerLogStart\"" } /** @@ -762,9 +768,10 @@ private class FsApplicationAttemptInfo( lastUpdated: Long, sparkUser: String, completed: Boolean, - val fileSize: Long) + val fileSize: Long, + appSparkVersion: String) extends ApplicationAttemptInfo( - attemptId, startTime, endTime, lastUpdated, sparkUser, completed) { + attemptId, startTime, endTime, lastUpdated, sparkUser, completed, appSparkVersion) { /** extend the superclass string value with the extra attributes of this class */ override def toString: String = { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 0e7a6c24d4fa..af1471763340 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -26,8 +26,9 @@ import org.apache.spark.ui.{UIUtils, WebUIPage} private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { + // stripXSS is called first to remove suspicious characters used in XSS attacks val requestedIncomplete = - Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean + Option(UIUtils.stripXSS(request.getParameter("showIncomplete"))).getOrElse("false").toBoolean val allAppsSize = parent.getApplicationList().count(_.completed != requestedIncomplete) val eventLogsUnderProcessCount = parent.getEventLogsUnderProcess() diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 816bf37e39fe..933209048cce 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -80,7 +80,7 @@ private[deploy] class Master( private val waitingDrivers = new ArrayBuffer[DriverInfo] private var nextDriverNumber = 0 - Utils.checkHost(address.host, "Expected hostname") + Utils.checkHost(address.host) private val masterMetricsSystem = 
MetricsSystem.createMetricsSystem("master", conf, securityMgr) private val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, @@ -231,7 +231,8 @@ private[deploy] class Master( logError("Leadership has been revoked -- master shutting down.") System.exit(0) - case RegisterWorker(id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl) => + case RegisterWorker( + id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl, masterAddress) => logInfo("Registering worker %s:%d with %d cores, %s RAM".format( workerHost, workerPort, cores, Utils.megabytesToString(memory))) if (state == RecoveryState.STANDBY) { @@ -243,7 +244,7 @@ private[deploy] class Master( workerRef, workerWebUiUrl) if (registerWorker(worker)) { persistenceEngine.addWorker(worker) - workerRef.send(RegisteredWorker(self, masterWebUiUrl)) + workerRef.send(RegisteredWorker(self, masterWebUiUrl, masterAddress)) schedule() } else { val workerAddress = worker.endpoint.address @@ -782,6 +783,10 @@ private[deploy] class Master( exec.state = ExecutorState.LOST exec.application.removeExecutor(exec) } + val failedApps = apps.find(p => p.driver.address.host == worker.endpoint.address.host) + for (app <- failedApps) { + finishApplication(app) + } for (driver <- worker.drivers.values) { if (driver.desc.supervise) { logInfo(s"Re-launching ${driver.id}") @@ -795,9 +800,12 @@ private[deploy] class Master( } private def relaunchDriver(driver: DriverInfo) { - driver.worker = None - driver.state = DriverState.RELAUNCHING - waitingDrivers += driver + removeDriver(driver.id, DriverState.RELAUNCHING, None) + val newDriver = createDriver(driver.desc) + persistenceEngine.addDriver(newDriver) + drivers.add(newDriver) + waitingDrivers += newDriver + schedule() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index c63793c16dce..615d2533cf08 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -60,12 +60,12 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) exte @tailrec private def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) case ("--host" | "-h") :: value :: tail => - Utils.checkHost(value, "Please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 4e20c10fd142..c87d6e24b78c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -32,7 +32,7 @@ private[spark] class WorkerInfo( val webUiAddress: String) extends Serializable { - Utils.checkHost(host, "Expected hostname") + Utils.checkHost(host) assert (port > 0) @transient var executors: mutable.HashMap[String, ExecutorDesc] = _ // executorId => info diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index a8d721f3e0d4..f40896457df9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -33,7 +33,8 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") /** Executor details for a particular application */ def render(request: HttpServletRequest): Seq[Node] = { - val appId = request.getParameter("appId") + // stripXSS is called first to remove suspicious characters used in XSS attacks + val appId = UIUtils.stripXSS(request.getParameter("appId")) val state = master.askSync[MasterStateResponse](RequestMasterState) val app = state.activeApps.find(_.id == appId) .getOrElse(state.completedApps.find(_.id == appId).orNull) @@ -99,11 +100,11 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
-          <h4> Executor Summary </h4>
+          <h4> Executor Summary ({allExecutors.length}) </h4>
           {executorsTable}
           {
             if (removedExecutors.nonEmpty) {
-              <h4> Removed Executors </h4> ++
+              <h4> Removed Executors ({removedExecutors.length}) </h4>
++ removedExecutorsTable } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 9351c72094e3..bc0bf6a1d970 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -57,8 +57,10 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = { if (parent.killEnabled && parent.master.securityMgr.checkModifyPermissions(request.getRemoteUser)) { - val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean - val id = Option(request.getParameter("id")) + // stripXSS is called first to remove suspicious characters used in XSS attacks + val killFlag = + Option(UIUtils.stripXSS(request.getParameter("terminate"))).getOrElse("false").toBoolean + val id = Option(UIUtils.stripXSS(request.getParameter("id"))) if (id.isDefined && killFlag) { action(id.get) } @@ -126,14 +128,14 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
-              <h4> Workers </h4>
+              <h4> Workers ({workers.length}) </h4>
               {workerTable}
-              <h4> Running Applications </h4>
+              <h4> Running Applications ({activeApps.length}) </h4>
               {activeAppsTable}
@@ -142,7 +144,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {if (hasDrivers) {
-                  <h4> Running Drivers </h4>
+                  <h4> Running Drivers ({activeDrivers.length}) </h4>
                   {activeDriversTable}
@@ -152,7 +154,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
-              <h4> Completed Applications </h4>
+              <h4> Completed Applications ({completedApps.length}) </h4>
               {completedAppsTable}
@@ -162,7 +164,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { if (hasDrivers) {
-                  <h4> Completed Drivers </h4>
+                  <h4> Completed Drivers ({completedDrivers.length}) </h4>
                   {completedDriversTable}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 00b9d1af373d..1198e3cb05ea 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -55,7 +55,7 @@ private[deploy] class Worker( private val host = rpcEnv.address.host private val port = rpcEnv.address.port - Utils.checkHost(host, "Expected hostname") + Utils.checkHost(host) assert (port > 0) // A scheduled executor used to send messages at the specified time. @@ -99,6 +99,20 @@ private[deploy] class Worker( private val testing: Boolean = sys.props.contains("spark.testing") private var master: Option[RpcEndpointRef] = None + + /** + * Whether to use the master address in `masterRpcAddresses` if possible. If it's disabled, Worker + * will just use the address received from Master. + */ + private val preferConfiguredMasterAddress = + conf.getBoolean("spark.worker.preferConfiguredMasterAddress", false) + /** + * The master address to connect in case of failure. When the connection is broken, worker will + * use this address to connect. This is usually just one of `masterRpcAddresses`. However, when + * a master is restarted or takes over leadership, it will be an address sent from master, which + * may not be in `masterRpcAddresses`. + */ + private var masterAddressToConnect: Option[RpcAddress] = None private var activeMasterUrl: String = "" private[worker] var activeMasterWebUiUrl : String = "" private var workerWebUiUrl: String = "" @@ -196,10 +210,19 @@ private[deploy] class Worker( metricsSystem.getServletHandlers.foreach(webUi.attachHandler) } - private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String) { + /** + * Change to use the new master. + * + * @param masterRef the new master ref + * @param uiUrl the new master Web UI address + * @param masterAddress the new master address which the worker should use to connect in case of + * failure + */ + private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String, masterAddress: RpcAddress) { // activeMasterUrl it's a valid Spark url since we receive it from master. 
activeMasterUrl = masterRef.address.toSparkURL activeMasterWebUiUrl = uiUrl + masterAddressToConnect = Some(masterAddress) master = Some(masterRef) connected = true if (conf.getBoolean("spark.ui.reverseProxy", false)) { @@ -266,7 +289,8 @@ private[deploy] class Worker( if (registerMasterFutures != null) { registerMasterFutures.foreach(_.cancel(true)) } - val masterAddress = masterRef.address + val masterAddress = + if (preferConfiguredMasterAddress) masterAddressToConnect.get else masterRef.address registerMasterFutures = Array(registerMasterThreadPool.submit(new Runnable { override def run(): Unit = { try { @@ -342,15 +366,27 @@ private[deploy] class Worker( } private def sendRegisterMessageToMaster(masterEndpoint: RpcEndpointRef): Unit = { - masterEndpoint.send(RegisterWorker(workerId, host, port, self, cores, memory, workerWebUiUrl)) + masterEndpoint.send(RegisterWorker( + workerId, + host, + port, + self, + cores, + memory, + workerWebUiUrl, + masterEndpoint.address)) } private def handleRegisterResponse(msg: RegisterWorkerResponse): Unit = synchronized { msg match { - case RegisteredWorker(masterRef, masterWebUiUrl) => - logInfo("Successfully registered with master " + masterRef.address.toSparkURL) + case RegisteredWorker(masterRef, masterWebUiUrl, masterAddress) => + if (preferConfiguredMasterAddress) { + logInfo("Successfully registered with master " + masterAddress.toSparkURL) + } else { + logInfo("Successfully registered with master " + masterRef.address.toSparkURL) + } registered = true - changeMaster(masterRef, masterWebUiUrl) + changeMaster(masterRef, masterWebUiUrl, masterAddress) forwordMessageScheduler.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { self.send(SendHeartbeat) @@ -419,7 +455,7 @@ private[deploy] class Worker( case MasterChanged(masterRef, masterWebUiUrl) => logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) - changeMaster(masterRef, masterWebUiUrl) + changeMaster(masterRef, masterWebUiUrl, masterRef.address) val execs = executors.values. 
map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) @@ -561,7 +597,8 @@ private[deploy] class Worker( } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - if (master.exists(_.address == remoteAddress)) { + if (master.exists(_.address == remoteAddress) || + masterAddressToConnect.exists(_ == remoteAddress)) { logInfo(s"$remoteAddress Disassociated !") masterDisconnected() } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 777020d4d5c8..bd07d342e04a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -68,12 +68,12 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { @tailrec private def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) case ("--host" | "-h") :: value :: tail => - Utils.checkHost(value, "Please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 80dc9bf8779d..2f5a5642d3ca 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -33,13 +33,16 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with private val supportedLogTypes = Set("stderr", "stdout") private val defaultBytes = 100 * 1024 + // stripXSS is called first to remove suspicious characters used in XSS attacks def renderLog(request: HttpServletRequest): String = { - val appId = Option(request.getParameter("appId")) - val executorId = Option(request.getParameter("executorId")) - val driverId = Option(request.getParameter("driverId")) - val logType = request.getParameter("logType") - val offset = Option(request.getParameter("offset")).map(_.toLong) - val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) + val appId = Option(UIUtils.stripXSS(request.getParameter("appId"))) + val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId"))) + val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId"))) + val logType = UIUtils.stripXSS(request.getParameter("logType")) + val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong) + val byteLength = + Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt) + .getOrElse(defaultBytes) val logDir = (appId, executorId, driverId) match { case (Some(a), Some(e), None) => @@ -55,13 +58,16 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with pre + logText } + // stripXSS is called first to remove suspicious characters used in XSS attacks def render(request: HttpServletRequest): Seq[Node] = { - val appId = Option(request.getParameter("appId")) - val executorId = Option(request.getParameter("executorId")) - val driverId = Option(request.getParameter("driverId")) - val logType = request.getParameter("logType") - val offset = Option(request.getParameter("offset")).map(_.toLong) - val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) + val appId = 
Option(UIUtils.stripXSS(request.getParameter("appId"))) + val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId"))) + val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId"))) + val logType = UIUtils.stripXSS(request.getParameter("logType")) + val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong) + val byteLength = + Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt) + .getOrElse(defaultBytes) val (logDir, params, pageName) = (appId, executorId, driverId) match { case (Some(a), Some(e), None) => diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b2b26ee107c0..a2f1aa22b006 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -191,11 +191,10 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Bootstrap to fetch the driver's Spark properties. val executorConf = new SparkConf - val port = executorConf.getInt("spark.executor.port", 0) val fetcher = RpcEnv.create( "driverPropsFetcher", hostname, - port, + -1, executorConf, new SecurityManager(executorConf), clientMode = true) @@ -221,7 +220,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } val env = SparkEnv.createExecutorEnv( - driverConf, executorId, hostname, port, cores, cfg.ioEncryptionKey, isLocal = false) + driverConf, executorId, hostname, cores, cfg.ioEncryptionKey, isLocal = false) env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env)) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 51b6c373c4da..5b396687dd11 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -71,7 +71,7 @@ private[spark] class Executor( private val conf = env.conf // No ip or host:port - just hostname - Utils.checkHost(executorHostname, "Expected executed slave to be a hostname") + Utils.checkHost(executorHostname) // must not have port specified. assert (0 == Utils.parseHostPort(executorHostname)._2) @@ -425,6 +425,7 @@ private[spark] class Executor( } } + setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) } catch { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 7f7921d56f49..e193ed222e22 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -278,4 +278,13 @@ package object config { "spark.io.compression.codec.") .booleanConf .createWithDefault(false) + + private[spark] val SHUFFLE_ACCURATE_BLOCK_THRESHOLD = + ConfigBuilder("spark.shuffle.accurateBlockThreshold") + .doc("When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will " + + "record the size accurately if it's above this config. 
This helps to prevent OOM by " + + "avoiding underestimating shuffle block size when fetch shuffle blocks.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(100 * 1024 * 1024) + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala index 28c45d800ed0..6da8865cd10d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala @@ -34,6 +34,7 @@ private[spark] class ApplicationEventListener extends SparkListener { var adminAcls: Option[String] = None var viewAclsGroups: Option[String] = None var adminAclsGroups: Option[String] = None + var appSparkVersion: Option[String] = None override def onApplicationStart(applicationStart: SparkListenerApplicationStart) { appName = Some(applicationStart.appName) @@ -57,4 +58,10 @@ private[spark] class ApplicationEventListener extends SparkListener { adminAclsGroups = allProperties.get("spark.admin.acls.groups") } } + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerLogStart(sparkVersion) => + appSparkVersion = Some(sparkVersion) + case _ => + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index aab177f257a8..875acc37e90f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -58,7 +58,7 @@ import org.apache.spark.util._ * set of map output files, and another to read those files after a barrier). In the end, every * stage will have only shuffle dependencies on other stages, and may compute multiple operations * inside it. The actual pipelining of these operations happens in the RDD.compute() functions of - * various RDDs (MappedRDD, FilteredRDD, etc). + * various RDDs * * In addition to coming up with a DAG of stages, the DAGScheduler also determines the preferred * locations to run each task on, based on the current cache status, and passes these to the @@ -618,12 +618,7 @@ class DAGScheduler( properties: Properties): Unit = { val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties) - // Note: Do not call Await.ready(future) because that calls `scala.concurrent.blocking`, - // which causes concurrent SQL executions to fail if a fork-join pool is used. Note that - // due to idiosyncrasies in Scala, `awaitPermission` is not actually used anywhere so it's - // safe to pass in null here. For more detail, see SPARK-13747. 
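Both this call site and the BlockManager ones further down in the patch now delegate to ThreadUtils.awaitReady rather than repeating the workaround described in the comment above. A minimal sketch of such a helper, assuming it does nothing more than centralize the null-CanAwait trick (the real ThreadUtils.awaitReady may additionally wrap exceptions):

    import scala.concurrent.{Awaitable, CanAwait}
    import scala.concurrent.duration.Duration

    object AwaitSketch {
      // Wait for `awaitable` without going through `scala.concurrent.blocking`,
      // which can starve a shared fork-join pool during concurrent SQL executions.
      // `CanAwait` is never inspected by `ready`, so passing null is safe here.
      def awaitReady[T](awaitable: Awaitable[T], atMost: Duration): awaitable.type = {
        val awaitPermission = null.asInstanceOf[CanAwait]
        awaitable.ready(atMost)(awaitPermission)
      }
    }
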
- val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] - waiter.completionFuture.ready(Duration.Inf)(awaitPermission) + ThreadUtils.awaitReady(waiter.completionFuture, Duration.Inf) waiter.completionFuture.value.get match { case scala.util.Success(_) => logInfo("Job %d finished: %s, took %f s".format diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index a7dbf87915b2..f48143633224 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -119,7 +119,7 @@ private[spark] class EventLoggingListener( val cstream = compressionCodec.map(_.compressedOutputStream(dstream)).getOrElse(dstream) val bstream = new BufferedOutputStream(cstream, outputBufferSize) - EventLoggingListener.initEventLog(bstream) + EventLoggingListener.initEventLog(bstream, testing, loggedEvents) fileSystem.setPermission(path, LOG_FILE_PERMISSIONS) writer = Some(new PrintWriter(bstream)) logInfo("Logging events to %s".format(logPath)) @@ -283,10 +283,17 @@ private[spark] object EventLoggingListener extends Logging { * * @param logStream Raw output stream to the event log file. */ - def initEventLog(logStream: OutputStream): Unit = { + def initEventLog( + logStream: OutputStream, + testing: Boolean, + loggedEvents: ArrayBuffer[JValue]): Unit = { val metadata = SparkListenerLogStart(SPARK_VERSION) - val metadataJson = compact(JsonProtocol.logStartToJson(metadata)) + "\n" + val eventJson = JsonProtocol.logStartToJson(metadata) + val metadataJson = compact(eventJson) + "\n" logStream.write(metadataJson.getBytes(StandardCharsets.UTF_8)) + if (testing && loggedEvents != null) { + loggedEvents += eventJson + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index b2e9a97129f0..048e0d018659 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -19,8 +19,13 @@ package org.apache.spark.scheduler import java.io.{Externalizable, ObjectInput, ObjectOutput} +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.roaringbitmap.RoaringBitmap +import org.apache.spark.SparkEnv +import org.apache.spark.internal.config import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils @@ -121,34 +126,41 @@ private[spark] class CompressedMapStatus( } /** - * A [[MapStatus]] implementation that only stores the average size of non-empty blocks, + * A [[MapStatus]] implementation that stores the accurate size of huge blocks, which are larger + * than spark.shuffle.accurateBlockThreshold. It stores the average size of other non-empty blocks, * plus a bitmap for tracking which blocks are empty. * * @param loc location where the task is being executed * @param numNonEmptyBlocks the number of non-empty blocks * @param emptyBlocks a bitmap tracking which blocks are empty - * @param avgSize average size of the non-empty blocks + * @param avgSize average size of the non-empty and non-huge blocks + * @param hugeBlockSizes sizes of huge blocks by their reduceId. 
*/ private[spark] class HighlyCompressedMapStatus private ( private[this] var loc: BlockManagerId, private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, - private[this] var avgSize: Long) + private[this] var avgSize: Long, + @transient private var hugeBlockSizes: Map[Int, Byte]) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization - require(loc == null || avgSize > 0 || numNonEmptyBlocks == 0, + require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1) // For deserialization only + protected def this() = this(null, -1, null, -1, null) // For deserialization only override def location: BlockManagerId = loc override def getSizeForBlock(reduceId: Int): Long = { + assert(hugeBlockSizes != null) if (emptyBlocks.contains(reduceId)) { 0 } else { - avgSize + hugeBlockSizes.get(reduceId) match { + case Some(size) => MapStatus.decompressSize(size) + case None => avgSize + } } } @@ -156,6 +168,11 @@ private[spark] class HighlyCompressedMapStatus private ( loc.writeExternal(out) emptyBlocks.writeExternal(out) out.writeLong(avgSize) + out.writeInt(hugeBlockSizes.size) + hugeBlockSizes.foreach { kv => + out.writeInt(kv._1) + out.writeByte(kv._2) + } } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -163,6 +180,14 @@ private[spark] class HighlyCompressedMapStatus private ( emptyBlocks = new RoaringBitmap() emptyBlocks.readExternal(in) avgSize = in.readLong() + val count = in.readInt() + val hugeBlockSizesArray = mutable.ArrayBuffer[Tuple2[Int, Byte]]() + (0 until count).foreach { _ => + val block = in.readInt() + val size = in.readByte() + hugeBlockSizesArray += Tuple2(block, size) + } + hugeBlockSizes = hugeBlockSizesArray.toMap } } @@ -178,11 +203,21 @@ private[spark] object HighlyCompressedMapStatus { // we expect that there will be far fewer of them, so we will perform fewer bitmap insertions. val emptyBlocks = new RoaringBitmap() val totalNumBlocks = uncompressedSizes.length + val threshold = Option(SparkEnv.get) + .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD)) + .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.defaultValue.get) + val hugeBlockSizesArray = ArrayBuffer[Tuple2[Int, Byte]]() while (i < totalNumBlocks) { - var size = uncompressedSizes(i) + val size = uncompressedSizes(i) if (size > 0) { numNonEmptyBlocks += 1 - totalSize += size + // Huge blocks are not included in the calculation for average size, thus size for smaller + // blocks is more accurate. 
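The branch that follows records each huge block's own one-byte compressed size instead of folding it into the running total. For intuition, a standalone sketch of a log-scale size encoding and of the skew it protects against; the 1.1 log base is an assumption about MapStatus.compressSize/decompressSize, not something stated in this hunk:

    object SizeEncodingSketch {
      // One-byte, log-scale block-size encoding in the spirit of
      // MapStatus.compressSize / decompressSize; the 1.1 base is an assumption.
      private val LogBase = 1.1

      def compress(size: Long): Byte =
        if (size == 0L) 0.toByte
        else if (size <= 1L) 1.toByte
        else math.min(255, math.ceil(math.log(size) / math.log(LogBase)).toInt).toByte

      def decompress(compressed: Byte): Long =
        if (compressed == 0) 0L else math.pow(LogBase, compressed & 0xFF).toLong

      def main(args: Array[String]): Unit = {
        // One 2 GiB block among 999 blocks of 10 KiB averages out to roughly 2 MiB,
        // so an average-only estimate budgets about 0.1% of what the huge block needs.
        val sizes = Array.fill(999)(10L * 1024) :+ (2L * 1024 * 1024 * 1024)
        println(s"average estimate: ${sizes.sum / sizes.length} bytes")
        // Recording the huge block individually keeps its estimate within the
        // coarse log-bucket error (on the order of 10%) instead.
        println(s"per-block estimate: ${decompress(compress(sizes.last))} bytes")
      }
    }
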
+ if (size < threshold) { + totalSize += size + } else { + hugeBlockSizesArray += Tuple2(i, MapStatus.compressSize(uncompressedSizes(i))) + } } else { emptyBlocks.add(i) } @@ -195,6 +230,7 @@ private[spark] object HighlyCompressedMapStatus { } emptyBlocks.trim() emptyBlocks.runOptimize() - new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize) + new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, + hugeBlockSizesArray.toMap) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index bc2e53071668..59f89a82a1da 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -160,9 +160,9 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent /** * An internal class that describes the metadata of an event log. - * This event is not meant to be posted to listeners downstream. */ -private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent /** * Interface for creating history listeners defined in other modules like SQL, which are used to diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 3ff363321e8c..3b0d3b1b150f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -71,7 +71,6 @@ private[spark] trait SparkListenerBus listener.onNodeUnblacklisted(nodeUnblacklisted) case blockUpdated: SparkListenerBlockUpdated => listener.onBlockUpdated(blockUpdated) - case logStart: SparkListenerLogStart => // ignore event log metadata case _ => listener.onOtherEvent(event) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 5c337b992c84..7767ef1803a0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -115,26 +115,33 @@ private[spark] abstract class Task[T]( case t: Throwable => e.addSuppressed(t) } + context.markTaskCompleted(Some(e)) throw e } finally { - // Call the task completion callbacks. - context.markTaskCompleted() try { - Utils.tryLogNonFatalError { - // Release memory used by this thread for unrolling blocks - SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP) - SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP) - // Notify any tasks waiting for execution memory to be freed to wake up and try to - // acquire memory again. This makes impossible the scenario where a task sleeps forever - // because there are no other tasks left to notify it. Since this is safe to do but may - // not be strictly necessary, we should revisit whether we can remove this in the future. - val memoryManager = SparkEnv.get.memoryManager - memoryManager.synchronized { memoryManager.notifyAll() } - } + // Call the task completion callbacks. If "markTaskCompleted" is called twice, the second + // one is no-op. 
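The no-op guarantee mentioned above is what lets the catch path call markTaskCompleted(Some(e)) while the finally path calls markTaskCompleted(None) without running the completion callbacks twice. A hypothetical, simplified illustration of that guard (names are illustrative, not the actual TaskContextImpl members):

    import scala.collection.mutable.ArrayBuffer

    class CompletionGuard {
      private val callbacks = ArrayBuffer.empty[Option[Throwable] => Unit]
      private var completed = false

      def addCallback(callback: Option[Throwable] => Unit): Unit = synchronized {
        callbacks += callback
      }

      // Safe to call from both the failure path (Some(error)) and the finally
      // path (None): only the first call runs the callbacks.
      def markCompleted(error: Option[Throwable]): Unit = synchronized {
        if (!completed) {
          completed = true
          callbacks.reverseIterator.foreach(_.apply(error))
        }
      }
    }
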
+ context.markTaskCompleted(None) } finally { - // Though we unset the ThreadLocal here, the context member variable itself is still queried - // directly in the TaskRunner to check for FetchFailedExceptions. - TaskContext.unset() + try { + Utils.tryLogNonFatalError { + // Release memory used by this thread for unrolling blocks + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP) + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask( + MemoryMode.OFF_HEAP) + // Notify any tasks waiting for execution memory to be freed to wake up and try to + // acquire memory again. This makes impossible the scenario where a task sleeps forever + // because there are no other tasks left to notify it. Since this is safe to do but may + // not be strictly necessary, we should revisit whether we can remove this in the + // future. + val memoryManager = SparkEnv.get.memoryManager + memoryManager.synchronized { memoryManager.notifyAll() } + } + } finally { + // Though we unset the ThreadLocal here, the context member variable itself is still + // queried directly in the TaskRunner to check for FetchFailedExceptions. + TaskContext.unset() + } } } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index a0239266d875..f039744e7f67 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -90,7 +90,8 @@ private[spark] object ApplicationsListResource { }, lastUpdated = new Date(internalAttemptInfo.lastUpdated), sparkUser = internalAttemptInfo.sparkUser, - completed = internalAttemptInfo.completed + completed = internalAttemptInfo.completed, + appSparkVersion = internalAttemptInfo.appSparkVersion ) } ) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 56d8e51732ff..f6203271f3cd 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -38,7 +38,8 @@ class ApplicationAttemptInfo private[spark]( val lastUpdated: Date, val duration: Long, val sparkUser: String, - val completed: Boolean = false) { + val completed: Boolean = false, + val appSparkVersion: String) { def getStartTimeEpoch: Long = startTime.getTime def getEndTimeEpoch: Long = endTime.getTime def getLastUpdatedEpoch: Long = lastUpdated.getTime diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 3219969bcd06..137d24b52515 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -23,14 +23,12 @@ import java.nio.channels.Channels import scala.collection.mutable import scala.collection.mutable.HashMap -import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ import scala.reflect.ClassTag import scala.util.Random import scala.util.control.NonFatal -import com.google.common.io.ByteStreams - import org.apache.spark._ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.internal.Logging @@ -41,7 +39,6 @@ import org.apache.spark.network.netty.SparkTransportConf import 
org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv -import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage.memory._ @@ -337,7 +334,7 @@ private[spark] class BlockManager( val task = asyncReregisterTask if (task != null) { try { - Await.ready(task, Duration.Inf) + ThreadUtils.awaitReady(task, Duration.Inf) } catch { case NonFatal(t) => throw new Exception("Error occurred while waiting for async. reregistration", t) @@ -612,12 +609,19 @@ private[spark] class BlockManager( /** * Return a list of locations for the given block, prioritizing the local machine since - * multiple block managers can share the same host. + * multiple block managers can share the same host, followed by hosts on the same rack. */ private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { val locs = Random.shuffle(master.getLocations(blockId)) val (preferredLocs, otherLocs) = locs.partition { loc => blockManagerId.host == loc.host } - preferredLocs ++ otherLocs + blockManagerId.topologyInfo match { + case None => preferredLocs ++ otherLocs + case Some(_) => + val (sameRackLocs, differentRackLocs) = otherLocs.partition { + loc => blockManagerId.topologyInfo == loc.topologyInfo + } + preferredLocs ++ sameRackLocs ++ differentRackLocs + } } /** @@ -912,7 +916,7 @@ private[spark] class BlockManager( if (level.replication > 1) { // Wait for asynchronous replication to finish try { - Await.ready(replicationFuture, Duration.Inf) + ThreadUtils.awaitReady(replicationFuture, Duration.Inf) } catch { case NonFatal(t) => throw new Exception("Error occurred while waiting for replication to finish", t) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index c37a3604d28f..2c3da0ee85e0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -46,7 +46,7 @@ class BlockManagerId private ( def executorId: String = executorId_ if (null != host_) { - Utils.checkHost(host_, "Expected hostname") + Utils.checkHost(host_) assert (port_ > 0) } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index bf4cf79e9faa..f271c56021e9 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -60,6 +60,8 @@ private[spark] class SparkUI private ( var appId: String = _ + var appSparkVersion = org.apache.spark.SPARK_VERSION + private var streamingJobProgressListener: Option[SparkListener] = None /** Initialize all components of the server. 
*/ @@ -118,7 +120,8 @@ private[spark] class SparkUI private ( duration = 0, lastUpdated = new Date(startTime), sparkUser = getSparkUser, - completed = false + completed = false, + appSparkVersion = appSparkVersion )) )) } @@ -139,6 +142,7 @@ private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) def appName: String = parent.appName + def appSparkVersion: String = parent.appSparkVersion } private[spark] object SparkUI { diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 79b0d81af52b..2610f673d27f 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -25,6 +25,8 @@ import scala.util.control.NonFatal import scala.xml._ import scala.xml.transform.{RewriteRule, RuleTransformer} +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.internal.Logging import org.apache.spark.ui.scope.RDDOperationGraph @@ -34,6 +36,8 @@ private[spark] object UIUtils extends Logging { val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped" val TABLE_CLASS_STRIPED_SORTABLE = TABLE_CLASS_STRIPED + " sortable" + private val NEWLINE_AND_SINGLE_QUOTE_REGEX = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r + // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. private val dateFormat = new ThreadLocal[SimpleDateFormat]() { override def initialValue(): SimpleDateFormat = @@ -228,7 +232,7 @@ private[spark] object UIUtils extends Logging {
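The LogPage changes earlier in this patch route every request parameter through UIUtils.stripXSS before it is parsed or echoed back. The body of stripXSS is not part of the hunks shown here; the sketch below assumes it drops the characters matched by NEWLINE_AND_SINGLE_QUOTE_REGEX and HTML-escapes whatever remains, using the commons-lang3 import added above:

    import org.apache.commons.lang3.StringEscapeUtils

    object StripXssSketch {
      // Same pattern as the NEWLINE_AND_SINGLE_QUOTE_REGEX added above.
      private val NEWLINE_AND_SINGLE_QUOTE_REGEX =
        raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r

      // Assumed behaviour: remove newline and single-quote sequences (literal and
      // URL-encoded), then HTML-escape the result before it reaches the page.
      def stripXSS(requestParameter: String): String = {
        if (requestParameter == null) {
          null
        } else {
          StringEscapeUtils.escapeHtml4(
            NEWLINE_AND_SINGLE_QUOTE_REGEX.replaceAllIn(requestParameter, ""))
        }
      }
    }

At a call site this composes as Option(stripXSS(request.getParameter("appId"))), mirroring the LogPage diff above.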